# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx_model import OnnxModel
logger = getLogger(__name__)
class FusionConstantFold(Fusion):
def __init__(self, model: OnnxModel):
super().__init__(model, "", ["Transpose"])
self.count = 0
def apply(self):
super().apply()
if self.count > 0:
logger.info(f"Constant Folded: {self.count}")
def fuse(self, node, input_name_to_nodes, output_name_to_node):
"""
Apply multiple fusions on Transpose nodes that can be constant folded.
"""
self.fuse_1(node, input_name_to_nodes, output_name_to_node)
self.fuse_2(node, input_name_to_nodes, output_name_to_node)
def fuse_1(self, node, input_name_to_nodes, output_name_to_node):
"""
Constant fold any initializer data representing a MatMul's
weights that are stored in a Transpose op
Ex: Transpose --> Gemm or Transpose --> MatMul
"""
# Check if Transpose node only has one input and one output
if len(node.input) != 1 or len(node.output) != 1:
logger.debug("fuse_constant_fold: node has more than one input or output")
return
# Check if input is initializer data
proto = self.model.get_initializer(node.input[0])
if proto is None:
logger.debug("fuse_constant_fold: failed to identify initializer input")
return
# Check that all nodes using input are Transpose ops that also only use the initializer data as input
skip = False
for child_node in input_name_to_nodes[node.input[0]]:
if not (child_node.op_type == "Transpose" and len(node.input) == 1):
skip = True
break
if skip:
logger.debug("fuse_constant_fold: other non-Transpose nodes use the initializer")
return
# Check that all nodes using output are Gemm or MatMul ops
for child_node in input_name_to_nodes[node.output[0]]:
if not (child_node.op_type == "Gemm" or child_node.op_type == "MatMul"):
skip = True
break
if skip:
logger.debug("fuse_constant_fold: other non-Gemm and non-MatMul nodes use the transposed data")
return
# Check if initializer data is 2D
weight = NumpyHelper.to_array(proto)
if len(weight.shape) != 2:
logger.debug("fuse_constant_fold: shape of initializer data is not 2D")
return
# Remove old TensorProto and add new TensorProto while re-using same name
name = proto.name
dtype = proto.data_type
self.remove_initializer(proto)
self.add_initializer(
name=name,
data_type=dtype,
dims=[weight.shape[1], weight.shape[0]],
vals=weight.T,
)
# Update weights input to be the initializer name and not
# the output of the Transpose op
for child_node in input_name_to_nodes[node.output[0]]:
for i in range(len(child_node.input)):
if child_node.input[i] == node.output[0]:
child_node.input[i] = node.input[0]
if child_node.op_type == "Gemm" and (i == 0 or i == 1):
# Ensure that transA/transB is set to 0 in Gemm
key = "transA" if i == 0 else "transB"
for j, attr_key in enumerate(child_node.attribute):
if attr_key.name == key:
child_node.attribute[j].i = 0
# Add node to list of nodes to remove
self.nodes_to_remove.append(node)
self.count += 1
def fuse_2(self, node, input_name_to_nodes, output_name_to_node):
"""
Constant fold any Transpose --> Transpose ops since the root input
is the final result
Ex: root_input --> Transpose --> Transpose --> next_node to root_input --> next_node
"""
# Check if Transpose node only has one input and one output
if len(node.input) != 1 or len(node.output) != 1:
logger.debug("fuse_constant_fold: node has more than one input or output")
return
# Check if parent node is Transpose node with only one input and one output
parent_node = self.model.match_parent(node, "Transpose", 0)
if parent_node is None:
logger.debug("fuse_constant_fold: failed to identify parent Transpose node")
return
if len(parent_node.input) != 1 or len(parent_node.output) != 1:
logger.debug("fuse_constant_fold: parent node has more than one input or output")
return
node_perm = node.attribute[0].ints
parent_node_perm = parent_node.attribute[0].ints
if node_perm != parent_node_perm:
logger.debug("fuse_constant_fold: Transpose node permutations aren't identical")
return
# For nodes that use output of child Transpose node as an input,
# replace that input with root_input
root_input = parent_node.input[0]
output_nodes = input_name_to_nodes[node.output[0]]
for output_node in output_nodes:
for i, input_ in enumerate(output_node.input):
if input_ == node.output[0]:
output_node.input[i] = root_input
# Add node to list of nodes to remove
self.nodes_to_remove.append(node)
self.nodes_to_remove.append(parent_node)
self.count += 1