[Mlir-commits] [mlir] a2004c3 - [mlir][spirv] Add RewriteInserts pass.
Lei Zhang
llvmlistbot at llvm.org
Fri Jun 26 06:57:32 PDT 2020
Author: Denis Khalikov
Date: 2020-06-26T09:57:20-04:00
New Revision: a2004c344bf0028313948e720da35da24bcbb7a9
URL: https://github.com/llvm/llvm-project/commit/a2004c344bf0028313948e720da35da24bcbb7a9
DIFF: https://github.com/llvm/llvm-project/commit/a2004c344bf0028313948e720da35da24bcbb7a9.diff
LOG: [mlir][spirv] Add RewriteInserts pass.
Add a pass to rewrite sequential chains of `spirv::CompositeInsert`
operations into `spirv::CompositeConstruct` operations.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D82198
Added:
mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/Passes.h
mlir/include/mlir/Dialect/SPIRV/Passes.td
mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h
index afc60805f75e..df516430be52 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Passes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h
@@ -46,6 +46,10 @@ createUpdateVersionCapabilityExtensionPass();
/// functions using the specification in the `spv.entry_point_abi` attribute.
std::unique_ptr<OperationPass<spirv::ModuleOp>> createLowerABIAttributesPass();
+/// Creates an operation pass that rewrites sequential chains of
+/// spv.CompositeInsert into spv.CompositeConstruct. A chain qualifies only
+/// when its ops each use a single index and together write elements 0..N-1
+/// of the composite in order, each op inserting into the previous result.
+std::unique_ptr<OperationPass<spirv::ModuleOp>> createRewriteInsertsPass();
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Passes.td
index e8972f519917..93a3516fc2b3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Passes.td
@@ -22,6 +22,12 @@ def SPIRVLowerABIAttributes : Pass<"spirv-lower-abi-attrs", "spirv::ModuleOp"> {
let constructor = "mlir::spirv::createLowerABIAttributesPass()";
}
+def SPIRVRewriteInsertsPass : Pass<"spirv-rewrite-inserts", "spirv::ModuleOp"> {
+  let summary = "Rewrite sequential chains of spv.CompositeInsert operations into "
+                "spv.CompositeConstruct operations";
+  let description = [{
+    Replaces a chain of `spv.CompositeInsert` ops with a single
+    `spv.CompositeConstruct` of the inserted objects. A chain is rewritten
+    only if every op in it uses exactly one index, each op inserts into the
+    result of the previous one, and the indices run 0, 1, ..., N-1 so that
+    every element of the N-element composite is written exactly once.
+    Inserts with nested (multi-index) paths are left untouched, as are
+    intermediate inserts that still have users outside the chain.
+  }];
+  let constructor = "mlir::spirv::createRewriteInsertsPass()";
+}
+
def SPIRVUpdateVCE : Pass<"spirv-update-vce", "spirv::ModuleOp"> {
let summary = "Deduce and attach minimal (version, capabilities, extensions) "
"requirements to spv.module ops";
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 632194f213d6..228d8482e6be 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRSPIRVTransforms
DecorateSPIRVCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
+ RewriteInsertsPass.cpp
UpdateVCEPass.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
new file mode 100644
index 000000000000..56fc4a9b60a6
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -0,0 +1,115 @@
+//===- RewriteInsertsPass.cpp - Rewrite CompositeInsert chains ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to rewrite sequential chains of
+// `spirv::CompositeInsert` operations into `spirv::CompositeConstruct`
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Module.h"
+
+#include <utility>
+
+using namespace mlir;
+
+namespace {
+
+/// Pass that folds a complete, in-order chain of single-index
+/// `spirv::CompositeInsertOp`s into one `spirv::CompositeConstructOp`
+/// whenever such a chain can be found.
+class RewriteInsertsPass
+    : public SPIRVRewriteInsertsPassBase<RewriteInsertsPass> {
+  /// Fills `chain` with the insertion chain that ends at `op`, ordered by the
+  /// element index each op writes. Succeeds only when `op` writes the last
+  /// element of its composite and a full chain back to index 0 exists.
+  LogicalResult
+  collectInsertionChain(spirv::CompositeInsertOp op,
+                        SmallVectorImpl<spirv::CompositeInsertOp> &chain);
+
+public:
+  void runOnOperation() override;
+};
+
+} // anonymous namespace
+
+/// Finds every complete insertion chain in the module and replaces it with a
+/// single `spirv::CompositeConstructOp` built from the chain's inserted
+/// objects in element order. All uses of the chain's last op are redirected
+/// to the new op, and chain ops left without users are erased.
+void RewriteInsertsPass::runOnOperation() {
+  // Gather all chains before touching the IR so the walk never visits ops
+  // that are being rewritten or erased.
+  SmallVector<SmallVector<spirv::CompositeInsertOp, 4>, 4> workList;
+  getOperation().walk([this, &workList](spirv::CompositeInsertOp op) {
+    SmallVector<spirv::CompositeInsertOp, 4> insertions;
+    if (succeeded(collectInsertionChain(op, insertions)))
+      workList.push_back(std::move(insertions));
+  });
+
+  for (const auto &insertions : workList) {
+    auto lastCompositeInsertOp = insertions.back();
+    auto compositeType = lastCompositeInsertOp.getType();
+    auto location = lastCompositeInsertOp.getLoc();
+
+    // Collect inserted objects; `insertions` is ordered by element index.
+    SmallVector<Value, 4> operands;
+    operands.reserve(insertions.size());
+    for (auto insertionOp : insertions)
+      operands.push_back(insertionOp.object());
+
+    OpBuilder builder(lastCompositeInsertOp);
+    auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
+        location, compositeType, operands);
+
+    lastCompositeInsertOp.replaceAllUsesWith(
+        compositeConstructOp.getOperation()->getResult(0));
+
+    // Erase from the end of the chain backwards: removing an op may free its
+    // predecessor. Intermediate ops that still have other users are kept.
+    for (auto insertOp : llvm::reverse(insertions)) {
+      if (insertOp.getOperation()->use_empty())
+        insertOp.erase();
+    }
+  }
+}
+
+/// Collects into `insertions` the chain of `spirv::CompositeInsertOp`s that
+/// ends at `op`, ordered by the element index each op writes.
+///
+/// The chain is accepted only when `op` writes the last element of its
+/// composite and, following the `composite` operand backwards, every
+/// predecessor is a `spirv::CompositeInsertOp` with a single index exactly
+/// one less than its successor's, down to index 0. On failure `insertions`
+/// may be left partially filled; callers must discard it.
+LogicalResult RewriteInsertsPass::collectInsertionChain(
+    spirv::CompositeInsertOp op,
+    SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
+  auto indicesArrayAttr = op.indices().cast<ArrayAttr>();
+  // TODO: handle nested composite object.
+  if (indicesArrayAttr.size() != 1)
+    return failure();
+
+  // Cooperative matrices and runtime arrays have no compile-time element
+  // count; `CompositeType::getNumElements()` is unreachable for them, so
+  // such inserts can never form a rewritable chain.
+  Type compositeType = op.composite().getType();
+  if (compositeType.isa<spirv::CooperativeMatrixNVType>() ||
+      compositeType.isa<spirv::RuntimeArrayType>())
+    return failure();
+
+  auto numElements =
+      compositeType.cast<spirv::CompositeType>().getNumElements();
+  auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt();
+  // Chains are collected backwards, so `op` must write the last element.
+  if (index + 1 != numElements)
+    return failure();
+
+  insertions.resize(numElements);
+  while (true) {
+    insertions[index] = op;
+    if (index == 0)
+      return success();
+
+    // Step to the op that produced the composite `op` inserts into.
+    op = op.composite().getDefiningOp<spirv::CompositeInsertOp>();
+    if (!op)
+      return failure();
+
+    --index;
+    indicesArrayAttr = op.indices().cast<ArrayAttr>();
+    if (indicesArrayAttr.size() != 1 ||
+        indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index)
+      return failure();
+  }
+}
+
+/// Builds a fresh instance of the pass that folds spv.CompositeInsert chains
+/// into spv.CompositeConstruct.
+std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
+mlir::spirv::createRewriteInsertsPass() {
+  std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>> pass =
+      std::make_unique<RewriteInsertsPass>();
+  return pass;
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
new file mode 100644
index 000000000000..1b265e3bcd42
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -spirv-rewrite-inserts %s -o - | FileCheck %s
+
+spv.module Logical GLSL450 {
+  // CHECK-LABEL: @rewrite
+  spv.func @rewrite(%value0 : f32, %value1 : f32, %value2 : f32, %value3 : i32, %value4: !spv.array<3xf32>) -> vector<3xf32> "None" {
+    %0 = spv.undef : vector<3xf32>
+    // CHECK-NOT: spv.CompositeInsert
+    // CHECK: %[[VEC:.*]] = spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
+    %1 = spv.CompositeInsert %value0, %0[0 : i32] : f32 into vector<3xf32>
+    %2 = spv.CompositeInsert %value1, %1[1 : i32] : f32 into vector<3xf32>
+    %3 = spv.CompositeInsert %value2, %2[2 : i32] : f32 into vector<3xf32>
+
+    %4 = spv.undef : !spv.array<4xf32>
+    // CHECK-NOT: spv.CompositeInsert
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spv.array<4 x f32>
+    %5 = spv.CompositeInsert %value0, %4[0 : i32] : f32 into !spv.array<4xf32>
+    %6 = spv.CompositeInsert %value1, %5[1 : i32] : f32 into !spv.array<4xf32>
+    %7 = spv.CompositeInsert %value2, %6[2 : i32] : f32 into !spv.array<4xf32>
+    %8 = spv.CompositeInsert %value0, %7[3 : i32] : f32 into !spv.array<4xf32>
+
+    %9 = spv.undef : !spv.struct<f32, i32, f32>
+    // CHECK-NOT: spv.CompositeInsert
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : !spv.struct<f32, i32, f32>
+    %10 = spv.CompositeInsert %value0, %9[0 : i32] : f32 into !spv.struct<f32, i32, f32>
+    %11 = spv.CompositeInsert %value3, %10[1 : i32] : i32 into !spv.struct<f32, i32, f32>
+    %12 = spv.CompositeInsert %value1, %11[2 : i32] : f32 into !spv.struct<f32, i32, f32>
+
+    %13 = spv.undef : !spv.struct<f32, !spv.array<3xf32>>
+    // CHECK-NOT: spv.CompositeInsert
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}} : !spv.struct<f32, !spv.array<3 x f32>>
+    %14 = spv.CompositeInsert %value0, %13[0 : i32] : f32 into !spv.struct<f32, !spv.array<3xf32>>
+    %15 = spv.CompositeInsert %value4, %14[1 : i32] : !spv.array<3xf32> into !spv.struct<f32, !spv.array<3xf32>>
+
+    // The returned value must now come from the rewritten construct.
+    // CHECK-NOT: spv.CompositeInsert
+    // CHECK: spv.ReturnValue %[[VEC]] : vector<3xf32>
+    spv.ReturnValue %3 : vector<3xf32>
+  }
+
+  // Inserts that do not write elements 0..N-1 in order must be left alone.
+  // CHECK-LABEL: @no_rewrite_out_of_order
+  spv.func @no_rewrite_out_of_order(%value0 : f32, %value1 : f32) -> vector<2xf32> "None" {
+    %0 = spv.undef : vector<2xf32>
+    // CHECK-NOT: spv.CompositeConstruct
+    // CHECK: spv.CompositeInsert
+    // CHECK: spv.CompositeInsert
+    %1 = spv.CompositeInsert %value1, %0[1 : i32] : f32 into vector<2xf32>
+    %2 = spv.CompositeInsert %value0, %1[0 : i32] : f32 into vector<2xf32>
+    spv.ReturnValue %2 : vector<2xf32>
+  }
+}
More information about the Mlir-commits
mailing list