[Mlir-commits] [mlir] c4472f8 - [mlir][std] Canonicalize extract_element(tensor_cast).

Stephan Herhut llvmlistbot at llvm.org
Tue Nov 17 05:42:10 PST 2020


Author: Stephan Herhut
Date: 2020-11-17T14:41:39+01:00
New Revision: c4472f8b4cda6a1802cb4543ae2fed94f798ece7

URL: https://github.com/llvm/llvm-project/commit/c4472f8b4cda6a1802cb4543ae2fed94f798ece7
DIFF: https://github.com/llvm/llvm-project/commit/c4472f8b4cda6a1802cb4543ae2fed94f798ece7.diff

LOG: [mlir][std] Canonicalize extract_element(tensor_cast).

Canonicalize extract_element(tensor_cast(v)) to just extract_element(v).

Differential Revision: https://reviews.llvm.org/D91621

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 469889321805..d2a2ca1f83c4 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1945,12 +1945,37 @@ struct ExtractElementFromDynamicTensorFromElements
   }
 };
 
+/// Canonicalizes extract_element(tensor_cast(%source)) into
+/// extract_element(%source), e.g.
+///   %val = tensor_cast %source : tensor<?xi32> to tensor<2xi32>
+///   %extracted_element = extract_element %val[%c0] : tensor<2xi32>
+/// becomes
+///   %extracted_element = extract_element %source[%c0] : tensor<?xi32>
+/// NOTE(review): this matches ExtractElementOp, yet it is registered via
+/// DynamicTensorFromElementsOp::getCanonicalizationPatterns.
+struct ExtractElementFromTensorCast
+    : public OpRewritePattern<ExtractElementOp> {
+  using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractElementOp extract,
+                                PatternRewriter &rewriter) const final {
+    auto tensorCast = extract.aggregate().getDefiningOp<TensorCastOp>();
+    if (!tensorCast)
+      return failure();
+    // Extract straight from the cast's operand, reusing the same indices.
+    rewriter.replaceOpWithNewOp<ExtractElementOp>(extract, tensorCast.source(),
+                                                  extract.getIndices());
+    return success();
+  }
+};
+
 } // namespace
 
 void DynamicTensorFromElementsOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<ExtractElementFromDynamicTensorFromElements,
-                 StaticDynamicTensorFromElements>(context);
+                 ExtractElementFromTensorCast,
+                 StaticDynamicTensorFromElements>(context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 08f3ac702596..4a74f5438a35 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1202,3 +1202,17 @@ func @subtensor(%t: tensor<8x16x4xf32>, %arg0 : index, %arg1 : index)
 
   return %2 : tensor<?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_element_from_tensor_cast
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
+func @extract_element_from_tensor_cast(%tensor: tensor<*xf32>) -> f32 {
+  // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+  %c0 = constant 0 : index
+  %casted = tensor_cast %tensor : tensor<*xf32> to tensor<?xf32>
+  // CHECK-NEXT: %[[RESULT:.*]] = extract_element %[[TENSOR]][%[[C0]]] : tensor<*xf32>
+  %result = extract_element %casted[%c0] : tensor<?xf32>
+  // CHECK-NEXT: return %[[RESULT]] : f32
+  return %result : f32
+}


        


More information about the Mlir-commits mailing list