[Mlir-commits] [mlir] f16abe5 - [mlir][Vector] Add a folder for vector.broadcast
Hanhan Wang
llvmlistbot at llvm.org
Thu Sep 17 08:55:22 PDT 2020
Author: Hanhan Wang
Date: 2020-09-17T08:54:51-07:00
New Revision: f16abe5f84eee8db18d5eb5a21ab543146626ea6
URL: https://github.com/llvm/llvm-project/commit/f16abe5f84eee8db18d5eb5a21ab543146626ea6
DIFF: https://github.com/llvm/llvm-project/commit/f16abe5f84eee8db18d5eb5a21ab543146626ea6.diff
LOG: [mlir][Vector] Add a folder for vector.broadcast
Fold the operation if the source is a scalar constant or splat constant.
Update transform-patterns-matmul-to-vector.mlir because the broadcast ops are folded in the conversion.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D87703
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 3cb1265b38ce..04aa18cfd648 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -270,6 +270,7 @@ def Vector_BroadcastOp :
}
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
+ let hasFolder = 1;
}
def Vector_ShuffleOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index c2b6f31cf114..c2cfaa54e448 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -929,6 +929,17 @@ static LogicalResult verify(BroadcastOp op) {
return success();
}
+// Folds vector.broadcast when its source operand is a constant. A scalar
+// int/index/float constant, or a splat elements constant, is materialized as
+// a splat DenseElementsAttr of the result vector type; anything else is left
+// unfolded (null OpFoldResult).
+OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+ // A null attribute means the source operand is not a known constant.
+ if (!operands[0])
+ return {};
+ auto vectorType = getVectorType();
+ // Scalar source: splat the scalar attribute across the whole result vector.
+ if (operands[0].getType().isIntOrIndexOrFloat())
+ return DenseElementsAttr::get(vectorType, operands[0]);
+ // Splat vector source: re-splat its single element at the result shape.
+ if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
+ return DenseElementsAttr::get(vectorType, attr.getSplatValue());
+ // Non-splat constant vectors are not folded here.
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// ShuffleOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
index 83e9461d66cc..683aeb241318 100644
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -13,13 +13,8 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
}
// CHECK-LABEL:func @matmul
-// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
// CHECK: store {{.*}}[] : memref<vector<8x16xf32>>
-//
-// CHECK: vector.broadcast {{.*}} : f32 to vector<16x12xf32>
// CHECK: store {{.*}}[] : memref<vector<16x12xf32>>
-//
-// CHECK: vector.broadcast {{.*}} : f32 to vector<8x12xf32>
// CHECK: store {{.*}}[] : memref<vector<8x12xf32>>
//
// CHECK: linalg.copy
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1b1362f94884..9c36f7684baf 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -385,3 +385,28 @@ func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector<4x8xf
%2 = vector.bitcast %1 : vector<4xi16> to vector<2xi32>
return %0, %2 : vector<4x8xf32>, vector<2xi32>
}
+
+// -----
+
+// Broadcasting a scalar constant folds to a dense splat vector constant.
+// CHECK-LABEL: broadcast_folding1
+// CHECK: %[[CST:.*]] = constant dense<42> : vector<4xi32>
+// CHECK-NOT: vector.broadcast
+// CHECK: return %[[CST]]
+func @broadcast_folding1() -> vector<4xi32> {
+ %0 = constant 42 : i32
+ %1 = vector.broadcast %0 : i32 to vector<4xi32>
+ return %1 : vector<4xi32>
+}
+
+// -----
+
+// A broadcast chain rooted at a scalar constant folds away entirely: the inner
+// broadcast becomes a splat constant, which the outer broadcast re-splats.
+// CHECK-LABEL: @broadcast_folding2
+// CHECK: %[[CST:.*]] = constant dense<42> : vector<4x16xi32>
+// CHECK-NOT: vector.broadcast
+// CHECK: return %[[CST]]
+func @broadcast_folding2() -> vector<4x16xi32> {
+ %0 = constant 42 : i32
+ %1 = vector.broadcast %0 : i32 to vector<16xi32>
+ %2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
+ return %2 : vector<4x16xi32>
+}
More information about the Mlir-commits mailing list