[Mlir-commits] [mlir] c897a7f - [mlir][Standard] Add canonicalizer for dynamic_tensor_from_elements

Stephan Herhut llvmlistbot at llvm.org
Tue Sep 15 06:38:34 PDT 2020


Author: Stephan Herhut
Date: 2020-09-15T15:38:14+02:00
New Revision: c897a7fb3e2a5c200a3e87a92886eab20d9f7fc7

URL: https://github.com/llvm/llvm-project/commit/c897a7fb3e2a5c200a3e87a92886eab20d9f7fc7
DIFF: https://github.com/llvm/llvm-project/commit/c897a7fb3e2a5c200a3e87a92886eab20d9f7fc7.diff

LOG: [mlir][Standard] Add canonicalizer for dynamic_tensor_from_elements

This adds canonicalizers for

- extracting an element from a dynamic_tensor_from_elements
- propagating constant operands to the type of dynamic_tensor_from_elements

Differential Revision: https://reviews.llvm.org/D87525

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 4d0cf76ec9d8..b0aa9b9e3c76 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1511,6 +1511,8 @@ def DynamicTensorFromElementsOp : Std_Op<"dynamic_tensor_from_elements",
               "ValueRange dynamicExtents, "
               "function_ref<void(OpBuilder &, Location, ValueRange)>">,
   ];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index c77bc12cca33..0c86c87384d3 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Matchers.h"
@@ -1730,6 +1731,101 @@ void DynamicTensorFromElementsOp::build(
   bodyBuilder(b, result.location, bodyBlock->getArguments());
 }
 
+namespace {
+
+/// Canonicalizes dynamic_tensor_from_elements operations with constant extent
+/// operands by folding those extents into the result type; a tensor_cast back
+/// to the original type keeps the IR well-typed. Negative constants stay
+/// dynamic: they are not valid static sizes (and -1 may alias kDynamicSize).
+struct StaticDynamicTensorFromElements
+    : public OpRewritePattern<DynamicTensorFromElementsOp> {
+  using OpRewritePattern<DynamicTensorFromElementsOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DynamicTensorFromElementsOp tensorFromElements,
+                                PatternRewriter &rewriter) const final {
+    auto resultType =
+        tensorFromElements.getResult().getType().cast<RankedTensorType>();
+    if (resultType.hasStaticShape())
+      return failure();
+
+    SmallVector<Value, 4> newOperands;
+    SmallVector<int64_t, 4> newShape;
+    auto operandsIt = tensorFromElements.dynamicExtents().begin();
+    for (int64_t dim : resultType.getShape()) {
+      if (dim != RankedTensorType::kDynamicSize) {
+        newShape.push_back(dim);
+        continue;
+      }
+      Value extent = *operandsIt++;
+      APInt index;
+      // Only non-negative constant extents can be encoded as static sizes.
+      if (!matchPattern(extent, m_ConstantInt(&index)) || index.isNegative()) {
+        newShape.push_back(RankedTensorType::kDynamicSize);
+        newOperands.push_back(extent);
+        continue;
+      }
+      newShape.push_back(index.getSExtValue());
+    }
+
+    // Bail out if no extent was folded; rewriting would not make progress.
+    if (newOperands.size() == tensorFromElements.dynamicExtents().size())
+      return failure();
+
+    auto loc = tensorFromElements.getLoc();
+    auto newOp = rewriter.create<DynamicTensorFromElementsOp>(
+        loc, RankedTensorType::get(newShape, resultType.getElementType()),
+        newOperands);
+    rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(),
+                                newOp.body().begin());
+    rewriter.replaceOpWithNewOp<TensorCastOp>(tensorFromElements, resultType,
+                                              newOp);
+    return success();
+  }
+};
+
+/// Folds an extract_element whose aggregate is produced by a side-effect-free
+/// dynamic_tensor_from_elements, e.g.
+///
+///   %tensor = dynamic_tensor_from_elements %x {
+///   ^bb0(%arg0: index):
+///     <computation>
+///     yield %1 : index
+///   } : tensor<?xindex>
+///   %extracted_element = extract_element %tensor[%c0] : tensor<?xindex>
+///
+/// into a copy of <computation> with %arg0 replaced by %c0; uses of
+/// %extracted_element are rewritten to the copied %1.
+struct ExtractElementFromDynamicTensorFromElements
+    : public OpRewritePattern<ExtractElementOp> {
+  using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractElementOp extract,
+                                PatternRewriter &rewriter) const final {
+    auto producer =
+        extract.aggregate().getDefiningOp<DynamicTensorFromElementsOp>();
+    // Cloning the body is only safe when the producer has no side effects.
+    if (!producer || !wouldOpBeTriviallyDead(producer))
+      return failure();
+
+    Block *generator = producer.getBody();
+    BlockAndValueMapping argsToIndices;
+    argsToIndices.map(generator->getArguments(), extract.indices());
+    for (Operation &nested : generator->without_terminator())
+      rewriter.clone(nested, argsToIndices);
+    Value yielded = cast<YieldOp>(generator->getTerminator()).value();
+    rewriter.replaceOp(extract, argsToIndices.lookupOrDefault(yielded));
+    return success();
+  }
+};
+
+} // namespace
+
+void DynamicTensorFromElementsOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ExtractElementFromDynamicTensorFromElements>(context);
+  results.insert<StaticDynamicTensorFromElements>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
@@ -1807,16 +1903,16 @@ struct ExtractElementFromTensorFromElements
     if (extract.indices().size() != 1)
       return failure();
 
-    auto tensor_from_elements = dyn_cast_or_null<TensorFromElementsOp>(
+    auto tensorFromElements = dyn_cast_or_null<TensorFromElementsOp>(
         extract.aggregate().getDefiningOp());
-    if (tensor_from_elements == nullptr)
+    if (tensorFromElements == nullptr)
       return failure();
 
     APInt index;
     if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
       return failure();
     rewriter.replaceOp(extract,
-                       tensor_from_elements.getOperand(index.getZExtValue()));
+                       tensorFromElements.getOperand(index.getZExtValue()));
     return success();
   }
 };

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 76fe82588be3..320418545893 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -986,3 +986,79 @@ func @extract_element_from_tensor_from_elements(%element : index) -> index {
   // CHECK: [[ARG]] : index
   return %extracted_element : index
 }
+
+// -----
+// The extract_element folds into a clone of the body's `dim`, indexed by %idx.
+// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements
+// CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
+func @extract_element_from_dynamic_tensor_from_elements(%idx: index, %tensor: tensor<*xf32>) -> index {
+  %size = rank %tensor : tensor<*xf32>
+  // CHECK-NEXT: %[[RES:.*]] = dim %[[TENSOR]], %[[IDX]]
+  %0 = dynamic_tensor_from_elements %size {
+    ^bb0(%arg0: index):
+    %1 = dim %tensor, %arg0 : tensor<*xf32>
+    yield %1 : index
+  } : tensor<?xindex>
+  %1 = extract_element %0[%idx] : tensor<?xindex>
+  // CHECK-NEXT: return %[[RES]]
+  return %1 : index
+}
+
+// -----
+// 2-D case: both block arguments are replaced by the extraction indices.
+// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements_2d
+// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
+func @extract_element_from_dynamic_tensor_from_elements_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
+  %size = rank %tensor : tensor<*xf32>
+  // CHECK-NEXT: %[[DIM0:.*]] = dim %[[TENSOR]], %[[IDX0]]
+  // CHECK-NEXT: %[[DIM1:.*]] = dim %[[TENSOR]], %[[IDX1]]
+  // CHECK-NEXT: %[[RES:.*]] = addi %[[DIM0]], %[[DIM1]]
+  %0 = dynamic_tensor_from_elements %size, %size {
+    ^bb0(%arg0: index, %arg1: index):
+    %1 = dim %tensor, %arg0 : tensor<*xf32>
+    %2 = dim %tensor, %arg1 : tensor<*xf32>
+    %3 = addi %1, %2 : index
+    yield %3 : index
+  } : tensor<?x?xindex>
+  %4 = extract_element %0[%idx0, %idx1] : tensor<?x?xindex>
+  // CHECK-NEXT: return %[[RES]]
+  return %4 : index
+}
+
+// -----
+// The body stores to %mem (a side effect), so the extract_element is kept.
+// CHECK-LABEL: func @extract_element_from_dynamic_tensor_from_elements_sideeffects
+// CHECK-SAME: %[[IDX:.*]]: index
+func @extract_element_from_dynamic_tensor_from_elements_sideeffects(%idx: index, %tensor: tensor<*xf32>) -> index {
+  %size = rank %tensor : tensor<*xf32>
+  %mem = alloc(%size) : memref<?xindex>
+  // CHECK: %[[DTENSOR:.*]] = dynamic_tensor_from_elements
+  %0 = dynamic_tensor_from_elements %size {
+    ^bb0(%arg0: index):
+    %1 = dim %tensor, %arg0 : tensor<*xf32>
+    store %1, %mem[%arg0] : memref<?xindex>
+    yield %1 : index
+  } : tensor<?xindex>
+  // CHECK: %[[RES:.*]] = extract_element %[[DTENSOR]][%[[IDX]]]
+  %1 = extract_element %0[%idx] : tensor<?xindex>
+  // CHECK-NEXT: return %[[RES]]
+  return %1 : index
+}
+
+// -----
+// The constant extent %c5 becomes static; a tensor_cast restores the old type.
+// CHECK-LABEL: @static_dynamic_tensor_from_elements
+// CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index)
+func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> {
+  %c5 = constant 5 : index
+  // CHECK: dynamic_tensor_from_elements %[[SIZE1]], %[[SIZE4]]
+  %0 = dynamic_tensor_from_elements %size1, %c5, %size4 {
+    ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
+    %1 = constant 32 : index
+    yield %1 : index
+  // CHECK: : tensor<3x?x5x7x?xindex>
+  } : tensor<3x?x?x7x?xindex>
+  // CHECK: tensor_cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
+  return %0 : tensor<3x?x?x7x?xindex>
+}
+


        


More information about the Mlir-commits mailing list