[Mlir-commits] [mlir] 0935c05 - [mlir][Vector] Add support for 0-D 'vector.shape_cast' lowering
Diego Caballero
llvmlistbot at llvm.org
Thu Jun 1 15:23:27 PDT 2023
Author: Diego Caballero
Date: 2023-06-01T22:22:16Z
New Revision: 0935c0556bedc35d841103b58eff9a6e3464ffe6
URL: https://github.com/llvm/llvm-project/commit/0935c0556bedc35d841103b58eff9a6e3464ffe6
DIFF: https://github.com/llvm/llvm-project/commit/0935c0556bedc35d841103b58eff9a6e3464ffe6.diff
LOG: [mlir][Vector] Add support for 0-D 'vector.shape_cast' lowering
This patch adds support for lowering `vector.shape_cast` ops that cast from and to 0-D vectors.
Reviewed By: nicolasvasilache, hanchung, awarzynski
Differential Revision: https://reviews.llvm.org/D151851
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index bd9716cbca94c..f2b28cad76745 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -151,8 +151,26 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
incIdx(srcIdx, sourceVectorType, srcRank - 1);
incIdx(resIdx, resultVectorType, resRank - 1);
}
- Value e = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
- result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
+
+ Value extract;
+ if (srcRank == 0) {
+ // 0-D vector special case
+ assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
+ extract = rewriter.create<vector::ExtractElementOp>(
+ loc, op.getSourceVectorType().getElementType(), op.getSource());
+ } else {
+ extract =
+ rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+ }
+
+ if (resRank == 0) {
+ // 0-D vector special case
+ assert(resIdx.empty() && "Unexpected indices for 0-D vector");
+ result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
+ } else {
+ result =
+ rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
+ }
}
rewriter.replaceOp(op, result);
return success();
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index 716537ed76ff7..f233a17244ff7 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -1,4 +1,3 @@
-
// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
// CHECK-LABEL: func @nop_shape_cast
@@ -124,9 +123,35 @@ func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
return %s : vector<2x1x3xf32>
}
+// CHECK-LABEL: func.func @shape_cast_0d1d(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<f32>) -> vector<1xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[VAL_2:.*]] = vector.extractelement %[[VAL_0]][] : vector<f32>
+// CHECK: %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_1]] [0] : f32 into vector<1xf32>
+// CHECK: return %[[VAL_3]] : vector<1xf32>
+// CHECK: }
+
+func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
+ %s = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
+ return %s : vector<1xf32>
+}
+
+// CHECK-LABEL: func.func @shape_cast_1d0d(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>) -> vector<f32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
+// CHECK: %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : vector<1xf32>
+// CHECK: %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][] : vector<f32>
+// CHECK: return %[[VAL_3]] : vector<f32>
+// CHECK: }
+
+func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
+ %s = vector.shape_cast %arg0 : vector<1xf32> to vector<f32>
+ return %s : vector<f32>
+}
+
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
- %f = transform.structured.match ops{["func.func"]} in %module_op
+ %f = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
%f2 = transform.vector.lower_shape_cast %f
More information about the Mlir-commits mailing list