[Mlir-commits] [mlir] 79f6f92 - [mlir][spirv] Enhance folding capability of spirv::CompositeExtractOp::fold
Jakub Kuderski
llvmlistbot at llvm.org
Fri May 26 16:23:19 PDT 2023
Author: Nishant Patel
Date: 2023-05-26T19:23:02-04:00
New Revision: 79f6f92e10b7cebe2d73f3bbe493ee1392779e26
URL: https://github.com/llvm/llvm-project/commit/79f6f92e10b7cebe2d73f3bbe493ee1392779e26
DIFF: https://github.com/llvm/llvm-project/commit/79f6f92e10b7cebe2d73f3bbe493ee1392779e26.diff
LOG: [mlir][spirv] Enhance folding capability of spirv::CompositeExtractOp::fold
This PR improves the `spirv::CompositeExtractOp::fold` function by adding a backtracking mechanism.
The updated function can now traverse a chain of `CompositeInsertOp`s to find a match.
Patch By: nbpatel
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D151536
Added:
Modified:
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index cb4ae4efe6062..9219e31f11692 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -141,19 +141,24 @@ OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
//===----------------------------------------------------------------------===//
OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
- if (auto insertOp =
- getComposite().getDefiningOp<spirv::CompositeInsertOp>()) {
+ Value compositeOp = getComposite();
+
+ while (auto insertOp =
+ compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
if (getIndices() == insertOp.getIndices())
return insertOp.getObject();
+ compositeOp = insertOp.getComposite();
}
if (auto constructOp =
- getComposite().getDefiningOp<spirv::CompositeConstructOp>()) {
+ compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
if (getIndices().size() == 1 &&
constructOp.getConstituents().size() == type.getNumElements()) {
- auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
- return constructOp.getConstituents()[i.getValue().getSExtValue()];
+ auto i = getIndices().begin()->cast<IntegerAttr>();
+ if (static_cast<size_t>(i.getValue().getSExtValue()) <
+ constructOp.getConstituents().size())
+ return constructOp.getConstituents()[i.getValue().getSExtValue()];
}
}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index f543ed44f649d..d5dc97a6245b1 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -185,6 +185,31 @@ func.func @extract_construct(%val1: vector<2xf32>, %val2: vector<2xf32>) -> (vec
return %1, %2 : vector<2xf32>, vector<2xf32>
}
+// -----
+
+ // CHECK-LABEL: fold_composite_op
+ // CHECK-SAME: (%[[COMP:.+]]: !spirv.struct<(f32, f32)>, %[[VAL1:.+]]: f32, %[[VAL2:.+]]: f32)
+ func.func @fold_composite_op(%composite: !spirv.struct<(f32, f32)>, %val1: f32, %val2: f32) -> f32 {
+ %insert = spirv.CompositeInsert %val1, %composite[0 : i32] : f32 into !spirv.struct<(f32, f32)>
+ %1 = spirv.CompositeInsert %val2, %insert[1 : i32] : f32 into !spirv.struct<(f32, f32)>
+ %2 = spirv.CompositeExtract %1[0 : i32] : !spirv.struct<(f32, f32)>
+ // CHECK-NEXT: return %[[VAL1]]
+ return %2 : f32
+ }
+
+// -----
+
+ // CHECK-LABEL: fold_composite_op
+ // CHECK-SAME: (%[[VAL1:.+]]: f32, %[[VAL2:.+]]: f32, %[[VAL3:.+]]: f32)
+ func.func @fold_composite_op(%val1: f32, %val2: f32, %val3: f32) -> f32 {
+ %composite = spirv.CompositeConstruct %val1, %val1, %val1 : (f32, f32, f32) -> !spirv.struct<(f32, f32, f32)>
+ %insert = spirv.CompositeInsert %val2, %composite[1 : i32] : f32 into !spirv.struct<(f32, f32, f32)>
+ %1 = spirv.CompositeInsert %val3, %insert[2 : i32] : f32 into !spirv.struct<(f32, f32, f32)>
+ %2 = spirv.CompositeExtract %1[0 : i32] : !spirv.struct<(f32, f32, f32)>
+ // CHECK-NEXT: return %[[VAL1]]
+ return %2 : f32
+ }
+
// -----
// Not yet implemented case
More information about the Mlir-commits mailing list