[Mlir-commits] [mlir] b006902 - [mlir] Fold trivial subtensor / subtensor_insert ops.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Feb 18 13:41:54 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-18T21:34:55Z
New Revision: b006902b2dfac792e8ade73798ca1b216654faf7

URL: https://github.com/llvm/llvm-project/commit/b006902b2dfac792e8ade73798ca1b216654faf7
DIFF: https://github.com/llvm/llvm-project/commit/b006902b2dfac792e8ade73798ca1b216654faf7.diff

LOG: [mlir] Fold trivial subtensor / subtensor_insert ops.

Static subtensor / subtensor_insert ops that have the same size as the source / destination tensor, are rooted at @[0..0], and have strides [1..1] are folded away.

Differential revision: https://reviews.llvm.org/D96991

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index a7142d298b66..9a253f8e814a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -364,6 +364,12 @@ bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
 /// comparison predicates.
 bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
                        const APFloat &rhs);
+
+/// Return true if ofr1 and ofr2 are the same integer constant attribute values
+/// or the same SSA value. A constant may be given either as an attribute or as
+/// the result of a ConstantOp. Integer bitwidth and type mismatches are ignored
+/// because there is no IndexAttr and IndexType has no bitwidth.
+bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 29863c82c502..64279c8fce3c 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2928,6 +2928,7 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
   }];
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -3026,6 +3027,7 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 52b41ca305d1..e916de4c0658 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -59,6 +59,27 @@ static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
 }
 
+/// Return true if ofr1 and ofr2 are the same integer constant attribute values
+/// or the same SSA value. Integer bitwidth and type mismatches are ignored:
+/// there is no IndexAttr and IndexType has no bitwidth.
+bool mlir::isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
+  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
+    Attribute attr = ofr.dyn_cast<Attribute>();
+    if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
+      attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
+    auto intAttr = attr.dyn_cast_or_null<IntegerAttr>();
+    // getSExtValue() asserts on values that do not fit in int64_t; skip them.
+    if (!intAttr || intAttr.getValue().getMinSignedBits() > 64)
+      return llvm::None;
+    return intAttr.getValue().getSExtValue();
+  };
+  auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
+  if (cst1 && cst2 && *cst1 == *cst2)
+    return true;
+  auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
+  return v1 && v2 && v1 == v2;
+}
+
 //===----------------------------------------------------------------------===//
 // StandardOpsDialect Interfaces
 //===----------------------------------------------------------------------===//
@@ -3557,6 +3578,34 @@ void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
           context);
 }
 
+/// Succeeds iff `op` is an identity slice of `shapedType`: every offset is 0,
+/// every stride is 1, and each size equals the matching shapedType dimension.
+/// llvm::zip stops at the shorter of the two ranges, so for rank-reducing ops
+/// only the leading sizes/dimensions are compared.
+static LogicalResult
+foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
+                                           ShapedType shapedType) {
+  OpBuilder b(op.getContext());
+  auto isIndexCst = [&](OpFoldResult ofr, int64_t value) {
+    return isEqualConstantIntOrValue(ofr, b.getIndexAttr(value));
+  };
+  auto zeroOffset = [&](OpFoldResult ofr) { return isIndexCst(ofr, 0); };
+  if (!llvm::all_of(op.getMixedOffsets(), zeroOffset))
+    return failure();
+  for (auto it : llvm::zip(op.getMixedSizes(), shapedType.getShape()))
+    if (!isIndexCst(std::get<0>(it), std::get<1>(it)))
+      return failure();
+  auto unitStride = [&](OpFoldResult ofr) { return isIndexCst(ofr, 1); };
+  return success(llvm::all_of(op.getMixedStrides(), unitStride));
+}
+
+OpFoldResult SubTensorOp::fold(ArrayRef<Attribute>) {
+  if (getSourceType() != getType() ||
+      failed(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
+    return {};
+  return source(); // Same type, zero offsets, full sizes, unit strides: no-op.
+}
+
 //===----------------------------------------------------------------------===//
 // SubTensorInsertOp
 //===----------------------------------------------------------------------===//
@@ -3597,6 +3646,13 @@ void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }
 
+OpFoldResult SubTensorInsertOp::fold(ArrayRef<Attribute>) {
+  if (getSourceType() != getType() ||
+      failed(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
+    return {};
+  return source(); // Same-typed full-extent insert: the result is `source`.
+}
+
 //===----------------------------------------------------------------------===//
 // TensorLoadOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index c864af8f5747..9247152e8677 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -157,3 +157,22 @@ func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
     memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
   return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
 }
+
+// CHECK-LABEL: func @trivial_subtensor
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x6x16x32xi8>
+//   CHECK-NOT:   subtensor
+//       CHECK:   return %[[ARG0]] : tensor<4x6x16x32xi8>
+func @trivial_subtensor(%src : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
+  %slice = subtensor %src[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<4x6x16x32xi8>
+  return %slice : tensor<4x6x16x32xi8>
+}
+
+// CHECK-LABEL: func @trivial_subtensor_insert
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x6x16x32xi8>
+//   CHECK-NOT:   subtensor
+//       CHECK:   return %[[ARG0]] : tensor<4x6x16x32xi8>
+func @trivial_subtensor_insert(%src : tensor<4x6x16x32xi8>, %dest : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
+  %res = subtensor_insert %src into %dest[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> into tensor<4x6x16x32xi8>
+  return %res : tensor<4x6x16x32xi8>
+}
+


        


More information about the Mlir-commits mailing list