[Mlir-commits] [mlir] bcc31d6 - [mlir][tensor] Implement DestinationStyleOpInterface for tensor.insert/insert_slice

Matthias Springer llvmlistbot at llvm.org
Thu Oct 27 01:34:07 PDT 2022


Author: Matthias Springer
Date: 2022-10-27T10:33:58+02:00
New Revision: bcc31d694f6e15f66c0caab49e4b65fb14e8612f

URL: https://github.com/llvm/llvm-project/commit/bcc31d694f6e15f66c0caab49e4b65fb14e8612f
DIFF: https://github.com/llvm/llvm-project/commit/bcc31d694f6e15f66c0caab49e4b65fb14e8612f.diff

LOG: [mlir][tensor] Implement DestinationStyleOpInterface for tensor.insert/insert_slice

Also allow unranked tensors/memrefs as outputs of destination-style ops.

This allows for a simpler implementation of the BufferizableOpInterface (in a subsequent commit).

Differential Revision: https://reviews.llvm.org/D136346

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
index 1e07378bdb453..77be053ce0e17 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -16,6 +16,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"

diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 86f3c06fbaccb..e2451f9429593 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/Tensor/IR/TensorBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
@@ -679,6 +680,7 @@ def Tensor_GenerateOp : Tensor_Op<"generate", [
 
 def Tensor_InsertOp : Tensor_Op<"insert", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DestinationStyleOpInterface,
     Pure,
     TypesMatchWith<"result type matches type of dest",
                    "dest", "result",
@@ -720,6 +722,12 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
       build($_builder, $_state, resType, scalar, dest, indices);
     }]>];
 
+  let extraClassDeclaration = [{
+    std::pair<int64_t, int64_t> getOutputsPositionRange() {
+      return {1, 2};  // `dest` operand
+    }
+  }];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 }
@@ -732,6 +740,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
     AttrSizedOperandSegments, 
+    DestinationStyleOpInterface,
     Pure, 
     OffsetSizeAndStrideOpInterface,
     TypesMatchWith<"expected result type to match dest type",
@@ -858,6 +867,10 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
     /// Return the number of leading operands before the `offsets`, `sizes` and
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
+
+    std::pair<int64_t, int64_t> getOutputsPositionRange() {
+      return {1, 2};  // `dest` operand
+    }
   }];
 
   let hasCanonicalizer = 1;

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 3f1c993a5de47..ad8500245ea53 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5132,6 +5132,7 @@ td_library(
     deps = [
         ":CastInterfacesTdFiles",
         ":ControlFlowInterfacesTdFiles",
+        ":DestinationStyleOpInterfaceTdFiles",
         ":InferTypeOpInterfaceTdFiles",
         ":OpBaseTdFiles",
         ":ParallelCombiningOpInterfaceTdFiles",


        


More information about the Mlir-commits mailing list