
Verification of PyTorch Graph Transformations using SMT Solvers

SMT solvers are powerful tools for deciding whether logical formulas over theories such as real arithmetic, bit-vectors, and floating-point numbers are satisfiable. They are expensive and rarely useful in production, but they are great for verifying the correctness of a given problem, such as an optimizing transformation in a machine learning library like PyTorch.

As an easier exercise, we can write a transformation on the PyTorch graph. This is made possible by the torch.fx module, whose symbolic tracer performs a symbolic execution of the code and produces a graph for us. A transformation takes in a torch.nn.Module, acquires a Graph from it, modifies it, and returns a new torch.nn.Module. Much like the fuser, we can look for patterns of ops that appear in a certain order within the module and fold them together. Note that optimizations involving layers whose behavior differs between training and inference, such as BatchNorm and Dropout, are only valid for models in inference mode (i.e. after model.eval()).

We will write a fusion transformation for the Linear-BatchNorm pattern, folding the BatchNorm parameters into the weight and bias of the preceding Linear layer. To verify the fusion, we traverse the graph with torch.fx and encode the input and output of each layer, together with the specified behavior of that layer, in Z3. We also encode the fused Linear-BN layer and compare its output with the output of the stacked Linear and BatchNorm layers in the model. We then ask Z3 for a counterexample: an input for which the two outputs differ. Here, a transformation is correct if the original and transformed functions return the same value for all valid inputs, so if no counterexample exists, our transformation is correct.

We will be implementing something similar to the optimization seen in this file. We start off by creating a model consisting of linear layers and batch norms, including some trickier patterns such as a nested sequential model and a wrapped submodule.

from typing import Type, Dict, Any, Tuple, Iterable
import copy
import torch.fx as fx
import torch
import torch.nn as nn

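class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        # BatchNorm1d matches the (batch, features) output of nn.Linear;
        # BatchNorm2d expects 4D (N, C, H, W) input and would fail here.
        self.mod = nn.BatchNorm1d(5)

    def forward(self, x):
        return self.mod(x)

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = nn.Linear(2, 5)
        self.bn_1 = nn.BatchNorm1d(5)
        self.linear_2 = nn.Linear(5, 2)
        # BatchNorm inside a Sequential, following a Linear outside it
        self.nested = nn.Sequential(
            nn.BatchNorm1d(2),
            nn.Linear(2, 5),
        )
        # BatchNorm hidden inside a custom submodule
        self.wrapped = Wrapper()

    def forward(self, x):
        x = self.linear_1(x)
        x = self.bn_1(x)
        x = self.linear_2(x)
        x = self.nested(x)
        x = self.wrapped(x)
        return x

model = MyModel()
print(model.eval())  # switch to inference mode; fusion is only valid here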

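We can then invoke the symbolic tracer from FX. As the torch.fx documentation describes, the tracer performs “symbolic execution” of the Python code: it feeds fake values, called Proxies, through the code and records the operations performed on them. This lets us capture the semantics of the model as a graph that we can inspect and transform.

Thus, we can get the computational graph using the following: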

traced_model = torch.fx.symbolic_trace(model)
print(traced_model.graph)
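The fusion pass below matches patterns by looking at each node's `op`, `target`, and `args`. This small sketch, using a hypothetical `LinearBN` toy module, shows what those fields look like for a Linear layer followed by a BatchNorm layer:

```python
import torch
import torch.fx as fx
import torch.nn as nn

class LinearBN(nn.Module):
    # Hypothetical toy model: one Linear followed by one BatchNorm1d.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 3)
        self.bn = nn.BatchNorm1d(3)

    def forward(self, x):
        return self.bn(self.linear(x))

gm = fx.symbolic_trace(LinearBN().eval())
for node in gm.graph.nodes:
    # op is e.g. "placeholder", "call_module", or "output";
    # for call_module nodes, target is the submodule's qualified name.
    print(node.op, node.target, node.args)
```

The BatchNorm node's first argument is the Linear node itself, which is how the pass finds the Linear that feeds a given BatchNorm.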

The fusion works by replacing the weight and bias of the linear layer with parameters that also absorb the following BatchNorm layer. The helpers below closely follow fuse_linear_bn_eval and fuse_linear_bn_weights in torch.nn.utils.fusion.
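In eval mode, BatchNorm applies a fixed per-feature affine map built from its running statistics, so it composes with the linear layer into a single affine map. Writing $z = Wx + b$ for the linear output, $\mu$ and $\sigma^2$ for the running mean and variance, and $\gamma$, $\beta$ for the BatchNorm weight and bias (all operations elementwise per feature):

$$
\begin{aligned}
\mathrm{BN}(z) &= \gamma \odot \frac{z - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta, \qquad s = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \\
\mathrm{BN}(Wx + b) &= \operatorname{diag}(s)\, W x + s \odot (b - \mu) + \beta \\
W' &= \operatorname{diag}(s)\, W, \qquad b' = s \odot (b - \mu) + \beta
\end{aligned}
$$

In the code, `bn_scale` is $s$, `linear_w * bn_scale.unsqueeze(-1)` multiplies row $i$ of $W$ by $s_i$ (that is, $\operatorname{diag}(s)\,W$), and `fused_b` is $b'$.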

def fuse_linear_bn_eval(linear, bn):
    fused_linear = copy.deepcopy(linear)
    fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights(
        fused_linear.weight,
        fused_linear.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
    )

    return fused_linear

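def fuse_linear_bn_weights(linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    # A Linear layer without bias behaves like one with a zero bias.
    if linear_b is None:
        linear_b = torch.zeros_like(bn_rm)
    # Per-feature scale: gamma / sqrt(running_var + eps)
    bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps)

    # Scale row i of W by bn_scale[i]; fold the mean shift and beta into the bias.
    fused_w = linear_w * bn_scale.unsqueeze(-1)
    fused_b = (linear_b - bn_rm) * bn_scale + bn_b

    return torch.nn.Parameter(fused_w, linear_w.requires_grad), torch.nn.Parameter(
        fused_b, linear_b.requires_grad
    )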
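Before reaching for a solver, a quick numerical spot check is useful. The sketch below recomputes the folded parameters by hand (it does not call the helpers above, so it runs on its own) and compares the fused layer against a Linear layer followed by an eval-mode BatchNorm1d with randomized running statistics. Passing on random inputs is evidence rather than proof, which is what the SMT encoding is for.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(4, 3)
bn = nn.BatchNorm1d(3)
# Give BatchNorm non-trivial statistics and affine parameters; with the
# defaults (mean 0, var 1, weight 1, bias 0) the fold is nearly the identity.
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-1.0, 1.0)
bn.eval()  # use the running statistics, as the fold assumes

with torch.no_grad():
    scale = bn.weight * torch.rsqrt(bn.running_var + bn.eps)
    fused = nn.Linear(4, 3)
    fused.weight.copy_(linear.weight * scale.unsqueeze(-1))
    fused.bias.copy_((linear.bias - bn.running_mean) * scale + bn.bias)

    x = torch.randn(8, 4)
    reference = bn(linear(x))
    folded = fused(x)

print(torch.allclose(reference, folded, atol=1e-5))  # True
```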

Now to write the graph traversal:

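def _parent_name(target: str) -> Tuple[str, str]:
    # Split a qualified name such as "nested.1" into ("nested", "1").
    *parent, name = target.rsplit('.', 1)
    return parent[0] if parent else '', name

def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    # Swap the submodule called by `node` for `new_module`.
    assert isinstance(node.target, str)
    parent_name, name = _parent_name(node.target)
    modules[node.target] = new_module
    setattr(modules[parent_name], name, new_module)

def fuse(model: torch.nn.Module) -> torch.nn.Module:
    model = copy.deepcopy(model)
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    for node in fx_model.graph.nodes:
        if node.op != 'call_module' or not node.args:
            continue
        prev = node.args[0]
        # The input must itself be a module call (not, e.g., the model input x).
        if not (isinstance(prev, fx.Node) and prev.op == 'call_module'):
            continue
        if type(modules[node.target]) is nn.BatchNorm1d and type(modules[prev.target]) is nn.Linear:
            # Skip if the Linear's output is also consumed elsewhere.
            if len(prev.users) > 1:
                continue
            linear = modules[prev.target]
            bn = modules[node.target]
            fused_linear = fuse_linear_bn_eval(linear, bn)
            replace_node_module(prev, modules, fused_linear)
            node.replace_all_uses_with(prev)
            fx_model.graph.erase_node(node)
    fx_model.graph.lint()
    fx_model.recompile()
    return fx_model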
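The `len(...users) > 1` guard matters when the Linear's output is consumed by something other than the BatchNorm: folding the BatchNorm into the Linear would then silently change what that other consumer sees. A minimal sketch with a hypothetical residual model (`SharedLinear`) shows how FX reports this:

```python
import torch
import torch.fx as fx
import torch.nn as nn

class SharedLinear(nn.Module):
    # Hypothetical model: the Linear output feeds both the BatchNorm and an add.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 3)
        self.bn = nn.BatchNorm1d(3)

    def forward(self, x):
        y = self.linear(x)
        return self.bn(y) + y

gm = fx.symbolic_trace(SharedLinear().eval())
linear_node = next(n for n in gm.graph.nodes if n.target == "linear")
print(len(linear_node.users))  # 2: the bn call and the add
```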

Now, we can call the fused model and check that its outputs match those of the original model.

model = MyModel().eval()  # BatchNorm folding is only valid in eval mode
fused_model = fuse(model)
fused_model.print_readable()  # prints the generated forward(); no print() needed
inp = torch.randn(5, 2)
torch.testing.assert_close(fused_model(inp), model(inp))

At this point, we need to decide on a verification strategy. We are looking for an input $x$ that satisfies the following inequality:

$$\mathrm{batchnorm}(\mathrm{linear}(x)) \neq \mathrm{fused}(x)$$

If such an input exists, our fusion transformation is buggy. To check this, we need a specification: a functionally correct description of the expected behavior of each layer. The simplest approach is to write out the expected behavior of the BatchNorm, Linear, and fused layers and ask whether there exists an input for which the outputs do not match. Z3's floating-point sorts (FPSort) would capture PyTorch's float semantics most faithfully; the helper below, however, encodes each element of our model's inputs and outputs as a Z3 Real (RealSort), which is simpler to solve but ignores floating-point rounding. We define a helper function to construct these vectors of symbolic values:

import z3

def const_from_layer(layer_name, length):
    # One symbolic real per feature, named e.g. "linear_in.0", "linear_in.1", ...
    names = [layer_name + "." + str(i) for i in range(length)]
    return z3.Consts(names, z3.RealSort())

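The length should be the number of features of the given layer. To illustrate this, we can encode a linear layer: each output element must equal the dot product of the corresponding weight row with the input, plus the bias. Z3's built-in Sum does not work on FP sorts, so we use a small reduce-based helper, MySum, instead.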

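from functools import reduce

def MySum(lst):
    # Left fold with +, relying only on Z3's operator overloading.
    return reduce(lambda a, b: a + b, lst, 0)

def encode_linear(linear_layer, input_vec, out_vec):
    W, b = linear_layer.parameters()
    # Create fresh symbolic vectors unless the caller passes them in.
    if input_vec is None:
        x = const_from_layer("linear_in", linear_layer.in_features)
    else:
        x = input_vec
    if out_vec is None:
        y = const_from_layer("linear_out", linear_layer.out_features)
    else:
        y = out_vec
    m, n = W.size()
    check_stmts = []
    for i in range(m):
        # y[i] == sum_j W[i, j] * x[j] + b[i], with the weights baked in as constants
        lhs = y[i]
        rhs = MySum([W[i, j].item() * x[j] for j in range(n)]) + b[i].item()
        check_stmts.append(z3.simplify(lhs == rhs))
    return check_stmts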

We can call encode_linear from our pass and check that the resulting constraints are satisfiable. First, let’s try out the functionality using just the linear layer's encoding.

from functools import reduce
import z3

def fuse(model: torch.nn.Module) -> torch.nn.Module:
    model = copy.deepcopy(model)
    fx_model: fx.GraphModule = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    for node in fx_model.graph.nodes:
        if node.op != 'call_module' or not node.args:
            continue
        prev = node.args[0]
        if not (isinstance(prev, fx.Node) and prev.op == 'call_module'):
            continue
        if type(modules[node.target]) is nn.BatchNorm1d and type(modules[prev.target]) is nn.Linear:
            if len(prev.users) > 1:
                continue
            constraints = []

            # testing just linear for now: fresh symbolic input/output vectors
            in_vec = None
            out_vec = None
            linear = modules[prev.target]
            constraints.extend(encode_linear(linear, in_vec, out_vec))

            bn = modules[node.target]
            fused_linear = fuse_linear_bn_eval(linear, bn)
            replace_node_module(prev, modules, fused_linear)
            node.replace_all_uses_with(prev)
            fx_model.graph.erase_node(node)
            # prints one satisfying assignment for this layer's constraints
            z3.solve(constraints)
    fx_model.graph.lint()
    fx_model.recompile()
    return fx_model

def MySum(lst):
    return reduce(lambda a, b: a + b, lst, 0)

def const_from_layer(layer_name, length):
    names = [layer_name + "." + str(i) for i in range(length)]
    return z3.Consts(names, z3.RealSort())

def encode_linear(linear_layer, input_vec, out_vec):
    W, b = linear_layer.parameters()
    if input_vec is None:
        x = const_from_layer("linear_in", linear_layer.in_features)
    else:
        x = input_vec
    if out_vec is None:
        y = const_from_layer("linear_out", linear_layer.out_features)
    else:
        y = out_vec
    m, n = W.size()
    check_stmts = []
    for i in range(m):
        lhs = y[i]
        rhs = MySum([W[i, j].item() * x[j] for j in range(n)]) + b[i].item()
        check_stmts.append(z3.simplify(lhs == rhs))
    return check_stmts

model = MyModel().eval()
fused_model = fuse(model)
fused_model.print_readable()

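Running this prints one satisfying assignment per fused Linear-BatchNorm pair (three for this model), followed by the generated code of the fused model: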

[linear_in.0 = -69733518362045290/17666971683502197,
 linear_out.1 = 13205634763112294935260819537299/12018348084015100000000000000000,
 linear_in.1 = 0,
 linear_out.2 = 297249491480084958520107826916501/147224764029184975000000000000000,
 linear_out.4 = 0,
 linear_out.3 = -93197874156910903666779488059027/50477061952863420000000000000000,
 linear_out.0 = 539545963527161555742391006788239/883348584175109850000000000000000]
[linear_in.0 = 0,
 linear_in.2 = -16563417017459870/5729876831173897,
 linear_in.1 = 0,
 linear_in.3 = 0,
 linear_out.1 = 0,
 linear_in.4 = 0,
 linear_out.0 = -85248675375066700628948803714593/286493841558694850000000000000000]
[linear_in.0 = 0,
 linear_out.1 = 3968338966369629/10000000000000000,
 linear_in.1 = 0,
 linear_out.2 = -2651418447494507/10000000000000000,
 linear_out.4 = -3312041759490967/10000000000000000,
 linear_out.3 = 6879096627235413/10000000000000000,
 linear_out.0 = 3556443750858307/10000000000000000]
class MyModel(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        linear_1 = self.linear_1(x);  x = None
        linear_2 = self.linear_2(linear_1);  linear_1 = None
        nested_1 = getattr(self.nested, "1")(linear_2);  linear_2 = None
        return nested_1

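Thus, Z3 finds a satisfying assignment for each linear layer's constraints. This confirms that the encoding is consistent, but it does not yet say anything about the correctness of the fusion: for that, we need to assert the inequality above and check that it is unsatisfiable. In the next post, we will extend this to our Linear-BatchNorm fusion.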