[Mlir-commits] [mlir] b4130e9 - [MLIR][PDL] Integration test of multi-root matching and related fixes.

Uday Bondhugula llvmlistbot at llvm.org
Mon Jan 3 18:40:27 PST 2022


Author: Stanislav Funiak
Date: 2022-01-04T08:03:45+05:30
New Revision: b4130e9eadfe46b4d3380c40ce8c3e900a0fd21b

URL: https://github.com/llvm/llvm-project/commit/b4130e9eadfe46b4d3380c40ce8c3e900a0fd21b
DIFF: https://github.com/llvm/llvm-project/commit/b4130e9eadfe46b4d3380c40ce8c3e900a0fd21b.diff

LOG: [MLIR][PDL] Integration test of multi-root matching and related fixes.

This diff adds an integration test for multi-root PDL matching. It consists of two subtests:
1) A 1-layer perceptron with split forward / backward operations.
2) A 2-layer perceptron with fused forward / backward operations.

Each test consists of a collection of hand-written patterns and the TensorFlow operations they are matched against. In the first test, the IR resulting from the rewrites remains a DAG and satisfies SSA dominance; in the second, it does not, so the operations are wrapped in a graph region (illustrated below).
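
For reference, a graph region is a region in which SSA dominance is not enforced, so a use may textually precede its def. A minimal sketch (the test.graph_region op comes from MLIR's test dialect; the foo.producer ops are hypothetical placeholders, not part of the test below):

  "test.graph_region"() ({
    // Inside a graph region, %b may be used before it is defined.
    %a = "foo.producer"(%b) : (i32) -> i32
    %b = "foo.producer"(%a) : (i32) -> i32
  }) : () -> ()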

This diff also includes two bug fixes:
1) Mark the pdl_interp dialect as a dependent dialect of the TestPDLByteCodePass. This is needed because we create ops from that dialect as part of the PDL-to-PDLInterp lowering.
2) Fix the starting index of the liveness range for ForEach operations (a bug exposed by the integration test); the numbering scheme is sketched after this list.
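
The second fix amounts to recording two indices per matcher operation, one on entry (before walking its nested regions) and one on exit, so that a value used inside the body of a pdl_interp.foreach stays live until the enclosing operation finishes. Below is a minimal standalone sketch of that numbering, using a hypothetical Node type rather than the actual MLIR Operation/Region classes:

  #include <functional>
  #include <map>
  #include <vector>

  struct Node {
    std::vector<Node *> children; // stands in for ops nested in regions
  };

  // Assign each node a first (entry) index and a last (exit) index so that a
  // node's [first, last] range encloses the ranges of everything nested in it.
  void numberNodes(Node *root, std::map<Node *, unsigned> &firstIndex,
                   std::map<Node *, unsigned> &lastIndex) {
    unsigned index = 0;
    std::function<void(Node *)> walk = [&](Node *node) {
      firstIndex[node] = index++;
      for (Node *child : node->children)
        walk(child);
      lastIndex[node] = index++;
    };
    walk(root);
  }

A liveness range that starts at the first index of the defining/using operation and ends at the last index of the end operation then covers everything nested inside the latter, which is what the updated allocateMemoryIndices in ByteCode.cpp below relies on.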

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D116082

Added: 
    mlir/test/Integration/Dialect/PDL/CPU/multiroot.mlir

Modified: 
    mlir/lib/Rewrite/ByteCode.cpp
    mlir/test/lib/Rewrite/TestPDLByteCode.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 765c47b2ed0cf..d6a07f9067fe4 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -551,10 +551,22 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
   // finding the minimal number of overlapping live ranges. This is essentially
   // a simplified form of register allocation where we don't necessarily have a
   // limited number of registers, but we still want to minimize the number used.
-  DenseMap<Operation *, unsigned> opToIndex;
-  matcherFunc.getBody().walk([&](Operation *op) {
-    opToIndex.insert(std::make_pair(op, opToIndex.size()));
-  });
+  DenseMap<Operation *, unsigned> opToFirstIndex;
+  DenseMap<Operation *, unsigned> opToLastIndex;
+
+  // A custom walk that marks the first and the last index of each operation.
+  // The entry marks the beginning of the liveness range for this operation,
+  // followed by nested operations, followed by the end of the liveness range.
+  unsigned index = 0;
+  llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
+    opToFirstIndex.try_emplace(op, index++);
+    for (Region &region : op->getRegions())
+      for (Block &block : region.getBlocks())
+        for (Operation &nested : block)
+          walk(&nested);
+    opToLastIndex.try_emplace(op, index++);
+  };
+  walk(matcherFunc);
 
   // Liveness info for each of the defs within the matcher.
   ByteCodeLiveRange::Allocator allocator;
@@ -578,8 +590,8 @@ void Generator::allocateMemoryIndices(FuncOp matcherFunc,
       // Set indices for the range of this block that the value is used.
       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
       defRangeIt->second.liveness->insert(
-          opToIndex[firstUseOrDef],
-          opToIndex[info->getEndOperation(value, firstUseOrDef)],
+          opToFirstIndex[firstUseOrDef],
+          opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
           /*dummyValue*/ 0);
 
       // Check to see if this value is a range type.

diff --git a/mlir/test/Integration/Dialect/PDL/CPU/multiroot.mlir b/mlir/test/Integration/Dialect/PDL/CPU/multiroot.mlir
new file mode 100644
index 0000000000000..be496ed3a675c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/PDL/CPU/multiroot.mlir
@@ -0,0 +1,294 @@
+// RUN: mlir-opt %s  -allow-unregistered-dialect -test-pdl-bytecode-pass -split-input-file | FileCheck %s
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// 1-layer perceptron with split fwd/bwd operations
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+  // fc_fwd
+  pdl.pattern : benefit(1) {
+    %in_type = pdl.type
+    %out_type = pdl.type
+    %weight_type = pdl.type
+    %rxact = pdl.operand : %in_type
+    %weight = pdl.operand : %weight_type
+
+    %attr0 = pdl.attribute false
+    %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)
+
+    pdl.rewrite %op0 {
+      %op1 = pdl.operation "kernel.FcFwd" (%rxact, %weight : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+      %val1 = pdl.result 0 of %op1  // txact
+      pdl.replace %op0 with (%val1 : !pdl.value)  // tf.MatMul
+    }
+  }
+
+  // fc_bwd
+  pdl.pattern : benefit(4) {
+    %in_type = pdl.type
+    %out_type = pdl.type
+    %weight_type = pdl.type
+    %const_type = pdl.type
+    %rxact = pdl.operand : %in_type
+    %rxdelta = pdl.operand : %out_type
+    %weight = pdl.operand : %weight_type
+
+    %attr0 = pdl.attribute true
+    %attr1 = pdl.attribute false
+    %op0 = pdl.operation "tf.MatMul" (%rxact, %rxdelta : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr1} -> (%weight_type : !pdl.type)
+    %val0 = pdl.result 0 of %op0
+    %op1 = pdl.operation "tf.Const" -> (%const_type : !pdl.type)
+    %val1 = pdl.result 0 of %op1
+    %op2 = pdl.operation "tf.Mul" (%val0, %val1 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
+    %val2 = pdl.result 0 of %op2
+    %op3 = pdl.operation "tf.Sub" (%weight, %val2 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
+
+    pdl.rewrite %op3 {
+      %op4 = pdl.operation "kernel.FcBwd" (%rxact, %rxdelta, %weight : !pdl.value, !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
+      %val4 = pdl.result 0 of %op4  // weight_out
+      pdl.replace %op3 with (%val4 : !pdl.value)  // tf.Sub
+      pdl.erase %op2  // tf.Mul
+      pdl.erase %op1  // tf.Const
+      pdl.erase %op0  // tf.MatMul
+    }
+  }
+
+  // softmax_cross_entropy
+  pdl.pattern : benefit(6) {
+    %in_type = pdl.type
+    %label_type = pdl.type
+    %loss_type = pdl.type
+    %mean_loss_type = pdl.type
+    %mean_const_type = pdl.type
+    %mul_const_type = pdl.type
+    %rxact = pdl.operand : %in_type
+    %rxlabel = pdl.operand : %label_type
+
+    %op0 = pdl.operation "tf.SparseSoftmaxCrossEntropyWithLogits" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%loss_type, %in_type : !pdl.type, !pdl.type)
+    %val0_0 = pdl.result 0 of %op0  // loss
+    %val0_1 = pdl.result 1 of %op0  // gradient
+    %op1 = pdl.operation "tf.Const" -> (%mean_const_type : !pdl.type)
+    %val1 = pdl.result 0 of %op1
+    %op2 = pdl.operation "tf.Mean" (%val0_0, %val1 : !pdl.value, !pdl.value) -> (%mean_loss_type : !pdl.type)
+    %val2 = pdl.result 0 of %op2
+    %op3 = pdl.operation "tf.PreventGradient" (%val0_1 : !pdl.value) -> (%in_type : !pdl.type)
+    %val3 = pdl.result 0 of %op3
+    %op4 = pdl.operation "tf.Const" -> (%mul_const_type : !pdl.type)
+    %val4 = pdl.result 0 of %op4
+    %op5 = pdl.operation "tf.Mul" (%val3, %val4 : !pdl.value, !pdl.value) -> (%in_type : !pdl.type)
+
+    pdl.rewrite {  // roots: %op2, %op5
+      %op6 = pdl.operation "kernel.SoftmaxCrossEntropy" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%mean_loss_type, %in_type : !pdl.type, !pdl.type)
+      %val6_0 = pdl.result 0 of %op6  // txloss
+      %val6_1 = pdl.result 1 of %op6  // txdelta
+      pdl.replace %op5 with (%val6_1 : !pdl.value)  // tf.Mul
+      pdl.erase %op4  // tf.Const
+      pdl.erase %op3  // tf.PreventGradient
+      pdl.replace %op2 with (%val6_0 : !pdl.value)  // tf.Mean
+      pdl.erase %op1  // tf.Const
+      pdl.erase %op0  // tf.SparseSoftmaxCrossEntropyWithLogits
+    }
+  }
+}
+
+// CHECK-LABEL: test.mlp_split
+// CHECK: %[[FWD:.*]] = "kernel.FcFwd"(%arg0, %arg2) : (tensor<2x20xf32>, tensor<20x10xf32>) -> tensor<2x10xf32>
+// CHECK: %[[SM:.*]]:2 = "kernel.SoftmaxCrossEntropy"(%[[FWD]], %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<f32>, tensor<2x10xf32>)
+// CHECK: %[[BWD:.*]] = "kernel.FcBwd"(%arg0, %[[SM]]#1, %arg2) : (tensor<2x20xf32>, tensor<2x10xf32>, tensor<20x10xf32>) -> tensor<20x10xf32>
+// CHECK: return %[[SM]]#0, %[[BWD]] : tensor<f32>, tensor<20x10xf32>
+module @ir attributes { test.mlp_split } {
+  func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<20x10xf32>) -> (tensor<f32>, tensor<20x10xf32>) {
+    %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+    %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
+    %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
+    %3 = "tf.MatMul"(%arg0, %arg2) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x10xf32>) -> tensor<2x10xf32>
+    %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%3, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>)
+    %4 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
+    %5 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32>
+    %6 = "tf.Mul"(%5, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32>
+    %7 = "tf.MatMul"(%arg0, %6) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x10xf32>) -> tensor<20x10xf32>
+    %8 = "tf.Mul"(%7, %1) : (tensor<20x10xf32>, tensor<f32>) -> tensor<20x10xf32>
+    %9 = "tf.Sub"(%arg2, %8) : (tensor<20x10xf32>, tensor<20x10xf32>) -> tensor<20x10xf32>
+    return %4, %9 : tensor<f32>, tensor<20x10xf32>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// 2-layer perceptron with fused fwd/bwd operations
+//===----------------------------------------------------------------------===//
+
+module @patterns {
+
+  // gradient descent
+  pdl.pattern : benefit(3) {
+    %const_type = pdl.type
+    %param_type = pdl.type
+    %param = pdl.operand : %param_type
+    %gradient = pdl.operand : %param_type
+
+    %attr0 = pdl.attribute
+    %op0 = pdl.operation "tf.Const" {"value" = %attr0} -> (%const_type : !pdl.type)
+    %val0 = pdl.result 0 of %op0
+    %op1 = pdl.operation "tf.Mul" (%gradient, %val0 : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)
+    %val1 = pdl.result 0 of %op1
+    %op2 = pdl.operation "tf.Sub" (%param, %val1 : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)
+
+    pdl.rewrite %op2 {
+      %op3 = pdl.operation "kernel.GD" (%param, %gradient : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)
+      %val3 = pdl.result 0 of %op3
+      pdl.replace %op2 with (%val3 : !pdl.value)  // tf.Sub
+      pdl.erase %op1  // tf.Mul
+    }
+  }
+
+  // first FC
+  pdl.pattern : benefit(8) {
+    %in_type = pdl.type
+    %out_type = pdl.type
+    %weight_type = pdl.type
+    %bias_type = pdl.type
+    %rxact = pdl.operand : %in_type
+    %rxdelta = pdl.operand : %out_type
+    %weight = pdl.operand : %weight_type
+    %bias = pdl.operand : %bias_type
+
+    %attr0 = pdl.attribute false
+    %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)
+    %val0 = pdl.result 0 of %op0
+    %op1 = pdl.operation "tf.BiasAdd" (%val0, %bias : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+    %val1 = pdl.result 0 of %op1
+    %op2 = pdl.operation "tf.Relu" (%val1 : !pdl.value) -> (%out_type : !pdl.type)
+    %val2 = pdl.result 0 of %op2
+    %op3 = pdl.operation "tf.ReluGrad" (%rxdelta, %val2 : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
+    %val3 = pdl.result 0 of %op3
+    %attr1 = pdl.attribute true
+    %op4 = pdl.operation "tf.MatMul" (%rxact, %val3 : !pdl.value, !pdl.value) {"transpose_a" = %attr1, "transpose_b" = %attr0} -> (%weight_type : !pdl.type)
+    %val4 = pdl.result 0 of %op4
+    %op5 = pdl.operation "kernel.GD" (%weight, %val4 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
+    %op6 = pdl.operation "tf.BiasAddGrad" (%val3 : !pdl.value) -> (%bias_type : !pdl.type)
+    %val6 = pdl.result 0 of %op6
+    %op7 = pdl.operation "kernel.GD" (%bias, %val6 : !pdl.value, !pdl.value) -> (%bias_type : !pdl.type)
+
+    pdl.rewrite {  // roots: %op2, %op5, %op7
+      %op8 = pdl.operation "kernel.FcWithBias" (%rxact, %rxdelta, %weight, %bias : !pdl.value, !pdl.value, !pdl.value, !pdl.value) -> (%out_type, %weight_type, %bias_type : !pdl.type, !pdl.type, !pdl.type)
+      %val8_0 = pdl.result 0 of %op8  // txact
+      %val8_1 = pdl.result 1 of %op8  // weight_out
+      %val8_2 = pdl.result 2 of %op8  // bias_out
+      pdl.replace %op7 with (%val8_2 : !pdl.value)  // kernel.GD
+      pdl.erase %op6  // tf.BiasAddGrad
+      pdl.replace %op5 with (%val8_1 : !pdl.value)  // kernel.GD
+      pdl.erase %op4  // tf.MatMul
+      pdl.erase %op3  // tf.ReluGrad
+      pdl.replace %op2 with (%val8_0 : !pdl.value)  // tf.Relu
+      pdl.erase %op1  // tf.BiasAdd
+      pdl.erase %op0  // tf.MatMul
+    }
+  }
+
+  // second FC
+  pdl.pattern : benefit(4) {
+    %in_type = pdl.type
+    %out_type = pdl.type
+    %weight_type = pdl.type
+    %rxact = pdl.operand : %in_type
+    %rxdelta = pdl.operand : %out_type
+    %weight = pdl.operand : %weight_type
+
+    %attr0 = pdl.attribute false
+    %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)
+    %attr1 = pdl.attribute true
+    %op1 = pdl.operation "tf.MatMul" (%rxdelta, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr1} -> (%in_type : !pdl.type)
+    %op2 = pdl.operation "tf.MatMul" (%rxact, %rxdelta : !pdl.value, !pdl.value) {"transpose_a" = %attr1, "transpose_b" = %attr0} -> (%weight_type : !pdl.type)
+    %val2 = pdl.result 0 of %op2
+    %op3 = pdl.operation "kernel.GD" (%weight, %val2 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
+
+    pdl.rewrite {  // roots: %op0, %op1, %op3
+      %op4 = pdl.operation "kernel.Fc" (%rxact, %rxdelta, %weight : !pdl.value, !pdl.value, !pdl.value) -> (%out_type, %in_type, %weight_type : !pdl.type, !pdl.type, !pdl.type)
+      %val4_0 = pdl.result 0 of %op4  // txact
+      %val4_1 = pdl.result 1 of %op4  // txdelta
+      %val4_2 = pdl.result 2 of %op4  // weight_out
+      pdl.replace %op3 with (%val4_2 : !pdl.value)  // kernel.GD
+      pdl.erase %op2  // tf.MatMul
+      pdl.replace %op1 with (%val4_1 : !pdl.value)  // tf.MatMul
+      pdl.replace %op0 with (%val4_0 : !pdl.value)  // tf.MatMul
+    }
+  }
+
+  // softmax_cross_entropy
+  pdl.pattern : benefit(6) {
+    %in_type = pdl.type
+    %label_type = pdl.type
+    %loss_type = pdl.type
+    %mean_loss_type = pdl.type
+    %mean_const_type = pdl.type
+    %mul_const_type = pdl.type
+    %rxact = pdl.operand : %in_type
+    %rxlabel = pdl.operand : %label_type
+
+    %op0 = pdl.operation "tf.SparseSoftmaxCrossEntropyWithLogits" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%loss_type, %in_type : !pdl.type, !pdl.type)
+    %val0_0 = pdl.result 0 of %op0  // loss
+    %val0_1 = pdl.result 1 of %op0  // gradient
+    %op1 = pdl.operation "tf.Const" -> (%mean_const_type : !pdl.type)
+    %val1 = pdl.result 0 of %op1
+    %op2 = pdl.operation "tf.Mean" (%val0_0, %val1 : !pdl.value, !pdl.value) -> (%mean_loss_type : !pdl.type)
+    %val2 = pdl.result 0 of %op2
+    %op3 = pdl.operation "tf.PreventGradient" (%val0_1 : !pdl.value) -> (%in_type : !pdl.type)
+    %val3 = pdl.result 0 of %op3
+    %op4 = pdl.operation "tf.Const" -> (%mul_const_type : !pdl.type)
+    %val4 = pdl.result 0 of %op4
+    %op5 = pdl.operation "tf.Mul" (%val3, %val4 : !pdl.value, !pdl.value) -> (%in_type : !pdl.type)
+
+    pdl.rewrite {  // roots: %op2, %op5
+      %op6 = pdl.operation "kernel.SoftmaxCrossEntropy" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%mean_loss_type, %in_type : !pdl.type, !pdl.type)
+      %val6_0 = pdl.result 0 of %op6  // txloss
+      %val6_1 = pdl.result 1 of %op6  // txdelta
+      pdl.replace %op5 with (%val6_1 : !pdl.value)  // tf.Mul
+      pdl.erase %op4  // tf.Const
+      pdl.erase %op3  // tf.PreventGradient
+      pdl.replace %op2 with (%val6_0 : !pdl.value)  // tf.Mean
+      pdl.erase %op1  // tf.Const
+      pdl.erase %op0  // tf.SparseSoftmaxCrossEntropyWithLogits
+    }
+  }
+}
+
+// CHECK-LABEL: test.mlp_fused
+// CHECK: %[[FC2:.*]]:3 = "kernel.Fc"(%[[FC1:.*]]#0, %[[SM:.*]]#1, %arg4) : (tensor<2x256xf32>, tensor<2x10xf32>, tensor<256x10xf32>) -> (tensor<2x10xf32>, tensor<2x256xf32>, tensor<256x10xf32>)
+// CHECK: %[[SM]]:2 = "kernel.SoftmaxCrossEntropy"(%[[FC2]]#0, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<f32>, tensor<2x10xf32>)
+// CHECK: %[[FC1]]:3 = "kernel.FcWithBias"(%arg0, %[[FC2]]#1, %arg3, %arg2) : (tensor<2x20xf32>, tensor<2x256xf32>, tensor<20x256xf32>, tensor<256xf32>) -> (tensor<2x256xf32>, tensor<20x256xf32>, tensor<256xf32>)
+module @ir attributes { test.mlp_fused } {
+  func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<256xf32>, %arg3: tensor<20x256xf32>, %arg4: tensor<256x10xf32>) -> () { // tensor<f32>, tensor<256xf32>, tensor<20x256xf32>, tensor<256x10xf32>) {
+    // The replacement operations fuse forward and backward pass; therefore, the
+    // resulting graph is not a DAG. To address this, we wrap the operations in
+    // a graph region.
+    "test.graph_region"() ({
+      %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+      %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
+      %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
+      %3 = "tf.MatMul"(%arg0, %arg3) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x256xf32>) -> tensor<2x256xf32>
+      %4 = "tf.BiasAdd"(%3, %arg2) {data_format = "NHWC"} : (tensor<2x256xf32>, tensor<256xf32>) -> tensor<2x256xf32>
+      %5 = "tf.Relu"(%4) : (tensor<2x256xf32>) -> tensor<2x256xf32>
+      %6 = "tf.MatMul"(%5, %arg4) {transpose_a = false, transpose_b = false} : (tensor<2x256xf32>, tensor<256x10xf32>) -> tensor<2x10xf32>
+      %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%6, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>)
+      %7 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
+      %8 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32>
+      %9 = "tf.Mul"(%8, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32>
+      %10 = "tf.MatMul"(%9, %arg4) {transpose_a = false, transpose_b = true} : (tensor<2x10xf32>, tensor<256x10xf32>) -> tensor<2x256xf32>
+      %11 = "tf.MatMul"(%5, %9) {transpose_a = true, transpose_b = false} : (tensor<2x256xf32>, tensor<2x10xf32>) -> tensor<256x10xf32>
+      %12 = "tf.ReluGrad"(%10, %5) : (tensor<2x256xf32>, tensor<2x256xf32>) -> tensor<2x256xf32>
+      %13 = "tf.BiasAddGrad"(%12) {data_format = "NHWC"} : (tensor<2x256xf32>) -> tensor<256xf32>
+      %14 = "tf.MatMul"(%arg0, %12) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x256xf32>) -> tensor<20x256xf32>
+      %15 = "tf.Mul"(%14, %1) : (tensor<20x256xf32>, tensor<f32>) -> tensor<20x256xf32>
+      %16 = "tf.Sub"(%arg3, %15) : (tensor<20x256xf32>, tensor<20x256xf32>) -> tensor<20x256xf32>
+      %17 = "tf.Mul"(%13, %1) : (tensor<256xf32>, tensor<f32>) -> tensor<256xf32>
+      %18 = "tf.Sub"(%arg2, %17) : (tensor<256xf32>, tensor<256xf32>) -> tensor<256xf32>
+      %19 = "tf.Mul"(%11, %1) : (tensor<256x10xf32>, tensor<f32>) -> tensor<256x10xf32>
+      %20 = "tf.Sub"(%arg4, %19) : (tensor<256x10xf32>, tensor<256x10xf32>) -> tensor<256x10xf32>
+    }) : () -> ()
+    return
+  }
+}

diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index ef62d73978d8b..748e54822718a 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -75,6 +76,11 @@ struct TestPDLByteCodePass
   StringRef getDescription() const final {
     return "Test PDL ByteCode functionality";
   }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // Mark the pdl_interp dialect as a dependent. This is needed, because we
+    // create ops from that dialect as a part of the PDL-to-PDLInterp lowering.
+    registry.insert<pdl_interp::PDLInterpDialect>();
+  }
   void runOnOperation() final {
     ModuleOp module = getOperation();
 


        

