[Mlir-commits] [mlir] a2004c3 - [mlir][spirv] Add RewriteInserts pass.

Lei Zhang llvmlistbot at llvm.org
Fri Jun 26 06:57:32 PDT 2020


Author: Denis Khalikov
Date: 2020-06-26T09:57:20-04:00
New Revision: a2004c344bf0028313948e720da35da24bcbb7a9

URL: https://github.com/llvm/llvm-project/commit/a2004c344bf0028313948e720da35da24bcbb7a9
DIFF: https://github.com/llvm/llvm-project/commit/a2004c344bf0028313948e720da35da24bcbb7a9.diff

LOG: [mlir][spirv] Add RewriteInserts pass.

Add a pass to rewrite sequential chains of `spirv::CompositeInsert`
operations into `spirv::CompositeConstruct` operations.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D82198

Added: 
    mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
    mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/Passes.h
    mlir/include/mlir/Dialect/SPIRV/Passes.td
    mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h
index afc60805f75e..df516430be52 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Passes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h
@@ -46,6 +46,10 @@ createUpdateVersionCapabilityExtensionPass();
 ///    functions using the specification in the `spv.entry_point_abi` attribute.
 std::unique_ptr<OperationPass<spirv::ModuleOp>> createLowerABIAttributesPass();
 
+/// Creates an operation pass that rewrites sequential chains of
+/// spv.CompositeInsert (single-index insertions writing elements 0..N-1 of an
+/// N-element composite, in order) into one spv.CompositeConstruct each.
+std::unique_ptr<OperationPass<spirv::ModuleOp>> createRewriteInsertsPass();
+
 } // namespace spirv
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/SPIRV/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Passes.td
index e8972f519917..93a3516fc2b3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Passes.td
@@ -22,6 +22,12 @@ def SPIRVLowerABIAttributes : Pass<"spirv-lower-abi-attrs", "spirv::ModuleOp"> {
   let constructor = "mlir::spirv::createLowerABIAttributesPass()";
 }
 
+def SPIRVRewriteInsertsPass : Pass<"spirv-rewrite-inserts", "spirv::ModuleOp"> {
+  let summary = "Rewrite sequential chains of spv.CompositeInsert operations into "
+                "spv.CompositeConstruct operations";
+  let description = [{
+    Finds chains of single-index `spv.CompositeInsert` operations that write
+    elements 0, 1, ..., N-1 of an N-element composite in order, where each
+    insertion consumes the result of the previous one. The last insertion of
+    such a chain is replaced by a single `spv.CompositeConstruct` built from
+    the inserted values, and insertions left without users are erased.
+    Chains that skip or reorder elements, or that use multi-level indices,
+    are left unchanged.
+  }];
+  let constructor = "mlir::spirv::createRewriteInsertsPass()";
+}
+
 def SPIRVUpdateVCE : Pass<"spirv-update-vce", "spirv::ModuleOp"> {
   let summary = "Deduce and attach minimal (version, capabilities, extensions) "
                 "requirements to spv.module ops";

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index 632194f213d6..228d8482e6be 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRSPIRVTransforms
   DecorateSPIRVCompositeTypeLayoutPass.cpp
   LowerABIAttributesPass.cpp
+  RewriteInsertsPass.cpp
   UpdateVCEPass.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
new file mode 100644
index 000000000000..56fc4a9b60a6
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -0,0 +1,115 @@
+//===- RewriteInsertsPass.cpp - Rewrite spv.CompositeInsert chains --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to rewrite sequential chains of
+// `spirv::CompositeInsert` operations into `spirv::CompositeConstruct`
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Module.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Folds a chain of `spirv::CompositeInsertOp`s that writes every element of
+/// a composite, one element per insertion and in index order, into a single
+/// `spirv::CompositeConstructOp`. Chains that do not have this exact shape
+/// are left untouched.
+class RewriteInsertsPass
+    : public SPIRVRewriteInsertsPassBase<RewriteInsertsPass> {
+private:
+  /// If `op` inserts the last element of its composite, walks back through
+  /// the `composite` operands and fills `insertions` so that `insertions[i]`
+  /// is the insertion writing element `i`. Returns failure when `op` does not
+  /// end such a complete, sequential chain; `insertions` may then be left
+  /// partially filled and must be discarded by the caller.
+  LogicalResult
+  collectInsertionChain(spirv::CompositeInsertOp op,
+                        SmallVectorImpl<spirv::CompositeInsertOp> &insertions);
+
+public:
+  void runOnOperation() override;
+};
+
+} // namespace
+
+/// Collects every complete insertion chain in the module first, then rewrites
+/// each one, so the IR is never mutated while it is being walked.
+void RewriteInsertsPass::runOnOperation() {
+  SmallVector<SmallVector<spirv::CompositeInsertOp, 4>, 4> chains;
+  getOperation().walk([this, &chains](spirv::CompositeInsertOp candidate) {
+    SmallVector<spirv::CompositeInsertOp, 4> chain;
+    if (failed(collectInsertionChain(candidate, chain)))
+      return;
+    chains.push_back(chain);
+  });
+
+  for (const auto &chain : chains) {
+    spirv::CompositeInsertOp tail = chain.back();
+
+    // chain[i] writes element i, so the inserted objects taken in chain order
+    // are exactly the constituents of the rebuilt composite.
+    SmallVector<Value, 4> constituents;
+    for (spirv::CompositeInsertOp link : chain)
+      constituents.push_back(link.object());
+
+    // Materialize the composite right before the tail of the chain and hand
+    // every user of the tail over to it.
+    OpBuilder builder(tail);
+    auto construct = builder.create<spirv::CompositeConstructOp>(
+        tail.getLoc(), tail.getType(), constituents);
+    Value replacement = construct.getOperation()->getResult(0);
+    tail.replaceAllUsesWith(replacement);
+
+    // Erase from the tail backwards: dropping one insertion may leave its
+    // predecessor without users. Links still used elsewhere are preserved.
+    for (spirv::CompositeInsertOp link : llvm::reverse(chain)) {
+      Operation *linkOp = link.getOperation();
+      if (!linkOp->use_empty())
+        continue;
+      link.erase();
+    }
+  }
+}
+
+/// Builds the insertion chain ending at `op`.
+///
+/// Succeeds only if `op` writes the last element of its composite through a
+/// single index, and each preceding link (reached through the `composite`
+/// operand) is a single-index insertion of the element just before it, all
+/// the way down to element 0. On success `insertions[i]` is the op writing
+/// element `i`. On failure `insertions` may be partially filled.
+LogicalResult RewriteInsertsPass::collectInsertionChain(
+    spirv::CompositeInsertOp op,
+    SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
+  auto indicesArrayAttr = op.indices().cast<ArrayAttr>();
+  // TODO: handle nested composite object.
+  if (indicesArrayAttr.size() != 1)
+    return failure();
+
+  // Only composites with a statically known element count can be rebuilt by a
+  // single spv.CompositeConstruct. CompositeType::getNumElements() must not
+  // be called on other composite kinds (e.g. runtime arrays): it reaches
+  // llvm_unreachable for them, so bail out first.
+  Type compositeType = op.composite().getType();
+  if (!compositeType.isa<VectorType>() &&
+      !compositeType.isa<spirv::ArrayType>() &&
+      !compositeType.isa<spirv::StructType>())
+    return failure();
+
+  auto numElements =
+      compositeType.cast<spirv::CompositeType>().getNumElements();
+
+  auto index = indicesArrayAttr[0].cast<IntegerAttr>().getInt();
+  // Chains are collected from their last link, so `op` has to write the last
+  // element of the composite.
+  if (index + 1 != numElements)
+    return failure();
+
+  insertions.resize(numElements);
+  while (true) {
+    insertions[index] = op;
+    if (index == 0)
+      return success();
+
+    // The previous link must be an insertion producing our `composite`
+    // operand and writing the element right before the current one.
+    op = op.composite().getDefiningOp<spirv::CompositeInsertOp>();
+    if (!op)
+      return failure();
+
+    --index;
+    indicesArrayAttr = op.indices().cast<ArrayAttr>();
+    if (indicesArrayAttr.size() != 1 ||
+        indicesArrayAttr[0].cast<IntegerAttr>().getInt() != index)
+      return failure();
+  }
+}
+
+/// Returns a fresh instance of the pass that folds complete, in-order
+/// spv.CompositeInsert chains into spv.CompositeConstruct.
+std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
+mlir::spirv::createRewriteInsertsPass() {
+  std::unique_ptr<OperationPass<spirv::ModuleOp>> pass =
+      std::make_unique<RewriteInsertsPass>();
+  return pass;
+}

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
new file mode 100644
index 000000000000..1b265e3bcd42
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -spirv-rewrite-inserts -split-input-file -verify-diagnostics %s -o - | FileCheck %s
+
+spv.module Logical GLSL450 {
+  // Complete, in-order chains are folded into one spv.CompositeConstruct and
+  // every insertion of the chain is erased.
+  // CHECK-LABEL: spv.func @rewrite
+  spv.func @rewrite(%value0 : f32, %value1 : f32, %value2 : f32, %value3 : i32, %value4: !spv.array<3xf32>) -> vector<3xf32> "None" {
+    %0 = spv.undef : vector<3xf32>
+    // CHECK-NOT: spv.CompositeInsert
+    // CHECK: [[VEC:%.*]] = spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
+    %1 = spv.CompositeInsert %value0, %0[0 : i32] : f32 into vector<3xf32>
+    %2 = spv.CompositeInsert %value1, %1[1 : i32] : f32 into vector<3xf32>
+    %3 = spv.CompositeInsert %value2, %2[2 : i32] : f32 into vector<3xf32>
+
+    %4 = spv.undef : !spv.array<4xf32>
+    // CHECK-NOT: spv.CompositeInsert
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spv.array<4 x f32>
+    %5 = spv.CompositeInsert %value0, %4[0 : i32] : f32 into !spv.array<4xf32>
+    %6 = spv.CompositeInsert %value1, %5[1 : i32] : f32 into !spv.array<4xf32>
+    %7 = spv.CompositeInsert %value2, %6[2 : i32] : f32 into !spv.array<4xf32>
+    %8 = spv.CompositeInsert %value0, %7[3 : i32] : f32 into !spv.array<4xf32>
+
+    %9 = spv.undef : !spv.struct<f32, i32, f32>
+    // CHECK-NOT: spv.CompositeInsert
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : !spv.struct<f32, i32, f32>
+    %10 = spv.CompositeInsert %value0, %9[0 : i32] : f32 into !spv.struct<f32, i32, f32>
+    %11 = spv.CompositeInsert %value3, %10[1 : i32] : i32 into !spv.struct<f32, i32, f32>
+    %12 = spv.CompositeInsert %value1, %11[2 : i32] : f32 into !spv.struct<f32, i32, f32>
+
+    %13 = spv.undef : !spv.struct<f32, !spv.array<3xf32>>
+    // CHECK-NOT: spv.CompositeInsert
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}} : !spv.struct<f32, !spv.array<3 x f32>>
+    %14 = spv.CompositeInsert %value0, %13[0 : i32] : f32 into !spv.struct<f32, !spv.array<3xf32>>
+    %15 = spv.CompositeInsert %value4, %14[1 : i32] : !spv.array<3xf32> into !spv.struct<f32, !spv.array<3xf32>>
+
+    // Users of the last insertion must now use the constructed composite.
+    // CHECK-NOT: spv.CompositeInsert
+    // CHECK: spv.ReturnValue [[VEC]] : vector<3xf32>
+    spv.ReturnValue %3 : vector<3xf32>
+  }
+
+  // Insertions that do not write elements 0..N-1 in order are left as is.
+  // CHECK-LABEL: spv.func @no_rewrite_out_of_order
+  spv.func @no_rewrite_out_of_order(%value0 : f32, %value1 : f32) -> vector<2xf32> "None" {
+    %0 = spv.undef : vector<2xf32>
+    // CHECK-NOT: spv.CompositeConstruct
+    // CHECK: spv.CompositeInsert
+    // CHECK: spv.CompositeInsert
+    // CHECK-NOT: spv.CompositeConstruct
+    // CHECK: spv.ReturnValue
+    %1 = spv.CompositeInsert %value1, %0[1 : i32] : f32 into vector<2xf32>
+    %2 = spv.CompositeInsert %value0, %1[0 : i32] : f32 into vector<2xf32>
+    spv.ReturnValue %2 : vector<2xf32>
+  }
+}


        


More information about the Mlir-commits mailing list