Advanced features

Introduction

This guide covers an advanced example of (meta)programming ONNX in Spox. It includes usage of control flow (subgraphs) and an alternative type (sequences).

We will be using ONNX Runtime (ORT) for these examples, as its implementation is more complete than the ONNX reference implementation.

[1]:
import warnings
import logging
import numpy as np
import onnx
import onnxruntime
import spox._future
from spox import argument, build, Tensor, Var
import spox.opset.ai.onnx.v17 as op

def const(value):
    # Promote a Python or numpy value to a Constant in the graph
    return op.constant(value=np.array(value))

def scalar(var: Var):
    # Reshape a Var into a scalar (rank-0 tensor)
    return op.reshape(var, const(np.array([], int)))

def run(model: onnx.ModelProto, **kwargs) -> list[np.ndarray]:
    # Run a model under ONNX Runtime, passing keyword arguments as inputs
    options = onnxruntime.SessionOptions()
    options.log_severity_level = 3  # only log errors and above
    return onnxruntime.InferenceSession(model.SerializeToString(), options).run(
        None,
        {k: np.array(v) for k, v in kwargs.items()}
    )

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.DEBUG)
spox._future.set_value_prop_backend(spox._future.ValuePropBackend.ONNXRUNTIME)
[2]:
x = argument(Tensor(float, ()))  # model input: a float64 scalar

Control flow

The ONNX standard supports conditionally evaluated subgraphs (think ‘functions’, subroutines, subprograms…). They aren’t evaluated by the runtime until required by the parent operator.

In Spox, subgraphs have first-class support and are constructed by providing a subgraph callback. The callback is called with special subgraph argument nodes produced by Spox. Any existing Var objects (including from outer scopes) may be used in a subgraph. The Spox build system will appropriately place nodes in the ONNX output based on where their outputs are used.

It is recommended to avoid side-effects in subgraph callbacks.

At the moment, value propagation is not run for operators with subgraphs to avoid unexpected build overhead.

Conditional - If

Conditionals are the simplest form of control flow in Spox, and may be computed with the If operator (available as if_). Only one of the branches is evaluated at runtime, depending on the value of the passed condition.

[3]:
# Compute relu of a scalar
(relu_x,) = op.if_(
    op.less(x, const(0.)),
    then_branch=lambda: [const(0.)],  # Branches have no arguments
    else_branch=lambda: [x]           # And return an iterable of Vars
)
# relu_x represents the conditional's result (either from the then or else branch)
[4]:
relu_model = build({"x": x}, {"r": relu_x})
[5]:
[float(run(relu_model, x=float(i))[0]) for i in range(-3, 5)]
[5]:
[0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0]
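
As noted earlier, value propagation is skipped for operators with subgraphs, so relu_x displays only its type. A small illustration (the exact repr text may differ):

print(relu_x)  # e.g. <Var from ai.onnx@16::If->outputs of float64>, with no '=' value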

Ifs can also be nested arbitrarily:

[6]:
# Compute a piecewise constant function at a scalar point
(pc_x,) = op.if_(
    op.less(x, const(0.)),
    then_branch=lambda: op.if_(
        op.less(x, const(-2.)),
        then_branch=lambda: [const(-3)],
        else_branch=lambda: [const(-1)],
    ),
    else_branch=lambda: op.if_(
        op.less(x, const(1.5)),
        then_branch=lambda: [const(2)],
        else_branch=lambda: [const(4)],
    ),
)
[7]:
pc_model = build({"x": x}, {"r": pc_x})
[8]:
[float(run(pc_model, x=float(i))[0]) for i in range(-5, 5)]
[8]:
[-3.0, -3.0, -3.0, -1.0, -1.0, 2.0, 2.0, 4.0, 4.0, 4.0]
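
For reference, the same piecewise-constant function written directly in NumPy (purely illustrative, independent of the model):

xs = np.arange(-5.0, 5.0)
np.piecewise(
    xs,
    [xs < -2.0, (xs >= -2.0) & (xs < 0.0), (xs >= 0.0) & (xs < 1.5), xs >= 1.5],
    [-3.0, -1.0, 2.0, 4.0],
).tolist()
# [-3.0, -3.0, -3.0, -1.0, -1.0, 2.0, 2.0, 4.0, 4.0, 4.0]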

Fold - Loop

Loop is an operator implementing a for-loop-like control flow construct. It can also be seen as a hybrid of the functional programming primitives take-while, fold, and scan.

This time the subgraph callback takes arguments: the iteration number, the current stop condition, and the accumulators. It should return the stop condition, the updated accumulators, and the scanned results.
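
Schematically, the callback has this shape (pseudocode):

body(iteration_number, condition, *accumulators)
    -> [condition, *updated_accumulators, *scan_results]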

A basic example computing 0 + 1 + ... + x:

[9]:
(sum_x,) = op.loop(
    op.add(op.cast(x, to=int), const(1)),  # x+1 iterations
    v_initial=[const(0.)],  # a := 0 at the start
    body=lambda i, _, a: [  # iteration (i), stopping (_), accumulator (a)
        const(True), # continue
        op.add(op.cast(i, to=float), a)  # step is a := float(i) + a
    ]
)
# ONNX drops shape information for accumulators,
# reshape into scalar explicitly
sum_x = scalar(sum_x)
[10]:
sum_model = build({"x": x}, {"r": sum_x})
[11]:
[float(run(sum_model, x=float(i))[0]) for i in range(8)]
[11]:
[0.0, 1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0]
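
These agree with the closed form x * (x + 1) / 2:

[i * (i + 1) / 2 for i in range(8)]
# [0.0, 1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0]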

A slightly more complex example involving a scan - computing factorials up to x:

[12]:
(fact_x, facts_x) = op.loop(
    op.add(op.cast(x, to=int), const(1)),  # x+1 iterations
    v_initial=[const(1.)],  # a := 1 at the start
    body=lambda i, _, a: [  # iteration (i), stopping (_), accumulator (a)
        const(True), # continue
        op.mul(
            op.add(op.cast(i, to=float), const(1.)),
            a
        ),  # a := (float(i) + 1) * a
        a  # scan a
    ]
)
fact_x = scalar(fact_x)
facts_x = op.reshape(facts_x, const([-1]))
[13]:
fact_model = build({"x": x}, {"r": fact_x, "rs": facts_x})
[14]:
[run(fact_model, x=float(i)) for i in range(5)]
[14]:
[[array(1.), array([1.])],
 [array(2.), array([1., 1.])],
 [array(6.), array([1., 1., 2.])],
 [array(24.), array([1., 1., 2., 6.])],
 [array(120.), array([ 1.,  1.,  2.,  6., 24.])]]

Keep in mind that more accumulators and scans can be present. ONNX resolves which value is what by counting positions - there are 2 + N arguments (iteration, stop, N accumulators) and 1 + N + K results (stop, N accumulators, K scans). In the above two examples we have N = 1, K = 0 and N = 1, K = 1 respectively.
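
As a mental model only, the overall semantics resemble the following plain-Python fold-with-scan (loop_semantics is a hypothetical helper, not part of Spox or ONNX). It reproduces the factorial example above:

def loop_semantics(max_iters, v_initial, body):
    carried = list(v_initial)  # the N accumulators
    scans = None               # lists collecting the K scan outputs
    cond = True
    for i in range(max_iters):
        if not cond:
            break
        cond, *outputs = body(i, cond, *carried)
        carried = outputs[:len(carried)]  # N updated accumulators
        step = outputs[len(carried):]     # K scan outputs from this step
        if scans is None:
            scans = [[] for _ in step]
        for acc, s in zip(scans, step):
            acc.append(s)
    return carried + [np.stack(acc) for acc in (scans or [])]

loop_semantics(5, [1.0], lambda i, cond, a: [True, (i + 1.0) * a, a])
# roughly: [120.0, array([ 1.,  1.,  2.,  6., 24.])]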

Sequences

Sequences are another type in the ONNX standard. In Spox they are treated on the same level as the more common tensors.

They may, however, cause typing problems, as support for sequences in operators and runtimes is more limited.

We’ll go through some basic sequence operators, previewing their behaviour with value propagation:

[15]:
elems = op.sequence_construct([const(i) for i in [1, 2, 3, 4]])
elems
[15]:
<Var from ai.onnx@11::SequenceConstruct->output_sequence of [int64] = [array(1), array(2), array(3), array(4)]>
[16]:
elems.type
[16]:
Sequence(elem_type=Tensor(dtype=int64, shape=()))
[17]:
scalar(op.sequence_at(elems, op.const(2)))
[17]:
<Var from ai.onnx@14::Reshape->reshaped of int64 = 3>
[18]:
op.sequence_insert(elems, const(5))
[18]:
<Var from ai.onnx@11::SequenceInsert->output_sequence of [int64] = [array(1), array(2), array(3), array(4), array(5)]>
[19]:
op.sequence_insert(elems, const(7), const(1))
[19]:
<Var from ai.onnx@11::SequenceInsert->output_sequence of [int64] = [array(1), array(7), array(2), array(3), array(4)]>
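
Other sequence operators behave analogously. For instance, SequenceErase (a sketch; the repr is abbreviated and assumes the same value propagation setup as above):

op.sequence_erase(elems, const(0))  # remove the element at position 0
# -> [int64] = [array(2), array(3), array(4)]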

Example - dynamic piecewise function

We’ll now go through a longer example combining the above. We’ll store the coefficients and intercepts of linear functions in sequences, along with the boundary points splitting the domain into the pieces of a piecewise linear function. A loop will find the piece corresponding to a query point.

In this example, coefficients, intercepts and pieces are constants - however, they could be computed dynamically.
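
To make the intent concrete, here is the same piecewise linear function in plain Python (piecewise_reference is a hypothetical helper for checking the model, not part of it):

def piecewise_reference(v: float) -> float:
    cs = [-1.0, 1.0, -2.0, 0.5]  # coefficients
    bs = [-3.0, 3.0, 3.0, -4.5]  # intercepts
    ps = [-3.0, 0.0, 3.0]        # piece boundaries
    k = 0
    # advance while the query point lies beyond the current boundary
    while k < len(ps) and v > ps[k]:
        k += 1
    return cs[k] * v + bs[k]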

[20]:
coefficients = op.sequence_construct([const(i) for i in [-1.0, 1.0, -2.0, 0.5]])
intercepts = op.sequence_construct([const(i) for i in [-3.0, 3.0, 3.0, -4.5]])
pieces = op.sequence_construct([const(i) for i in [-3.0, 0.0, 3.0]])
[21]:
(piece,) = op.loop(
    op.add(op.sequence_length(pieces), const(1)),
    v_initial=[const(0)],
    body=lambda i, _, _i: [
        # Continue while i < len(pieces) and x > pieces[i]
        *op.if_(
            op.less(i, op.sequence_length(pieces)),
            then_branch=lambda: [op.greater(x, op.sequence_at(pieces, i))],
            else_branch=lambda: [op.const(False)]
        ),
        i  # keep the i
    ]
)
[22]:
result = op.add(
    op.mul(x, op.sequence_at(coefficients, piece)),
    op.sequence_at(intercepts, piece)
)
[23]:
lp_model = build({"x": x}, {"r": result})
[24]:
[float(run(lp_model, x=float(i))[0]) for i in range(-5, 6)]
[24]:
[2.0, 1.0, 0.0, 1.0, 2.0, 3.0, 1.0, -1.0, -3.0, -2.5, -2.0]
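
These match the plain-Python reference defined earlier:

[piecewise_reference(float(i)) for i in range(-5, 6)]
# [2.0, 1.0, 0.0, 1.0, 2.0, 3.0, 1.0, -1.0, -3.0, -2.5, -2.0]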