[Mlir-commits] [mlir] 92cf9f1 - [mlir][linalg] Cast back to the original type after making linalg.generic outputs more static

Benjamin Kramer llvmlistbot at llvm.org
Thu Feb 24 04:46:11 PST 2022


Author: Benjamin Kramer
Date: 2022-02-24T13:35:54+01:00
New Revision: 92cf9f14814a5e8308c431095fb2205202445676

URL: https://github.com/llvm/llvm-project/commit/92cf9f14814a5e8308c431095fb2205202445676
DIFF: https://github.com/llvm/llvm-project/commit/92cf9f14814a5e8308c431095fb2205202445676.diff

LOG: [mlir][linalg] Cast back to the original type after making linalg.generic outputs more static

This codepath was entirely untested.

Differential Revision: https://reviews.llvm.org/D120473

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 319a1c318ff8a..a33945a873126 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -991,7 +991,7 @@ struct InferStaticShapeOfOperands : public OpRewritePattern<GenericOp> {
       Type oldType = oldResult.getType();
       replacements.push_back(
           (newType != oldType)
-              ? rewriter.create<tensor::CastOp>(loc, newType, newResult)
+              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
               : newResult);
     }
     rewriter.replaceOp(genericOp, replacements);

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 8a3f201f7cc26..c3405887431ff 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -780,3 +780,27 @@ func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor
     //  CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
     //  CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
 }
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @cast_dest
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<1x?x?xf32>,
+func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2: index, %arg3: index, %arg4: index) -> tensor<?x?x?xf32> {
+  %0 = linalg.init_tensor [%arg2, %arg3, %arg4] : tensor<?x?x?xf32>
+  %1 = tensor.cast %arg1 : tensor<1x?x?xf32> to tensor<?x?x?xf32>
+  %2 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg0, %1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+    outs(%0 : tensor<?x?x?xf32>) {
+  ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
+    %3 = arith.subf %arg5, %arg6 : f32
+    linalg.yield %3 : f32
+  } -> tensor<?x?x?xf32>
+  return %2 : tensor<?x?x?xf32>
+// CHECK:      %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%{{.*}}, %[[ARG1]] : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+// CHECK-SAME: outs(%{{.*}} : tensor<1x?x?xf32>)
+// CHECK: tensor.cast %[[GENERIC_OP]] : tensor<1x?x?xf32> to tensor<?x?x?xf32>
+}


        


More information about the Mlir-commits mailing list