Advanced features
Introduction
This guide covers advanced examples of (meta)programming ONNX with Spox. It includes control flow (subgraphs) and an alternative value type (sequences).
We use ONNX Runtime (ORT) for these examples, as its operator coverage is more complete than that of the ONNX reference implementation.
[1]:
import warnings
import logging
import numpy as np
import onnx
import onnxruntime
import spox._future
from spox import argument, build, Tensor, Var
import spox.opset.ai.onnx.v17 as op

def const(value):
    # Wrap a Python/NumPy value in an ONNX Constant node
    return op.constant(value=np.array(value))

def scalar(var: Var):
    # Reshape to an empty target shape, i.e. into a rank-0 (scalar) tensor
    return op.reshape(var, const(np.array([], int)))

def run(model: onnx.ModelProto, **kwargs) -> list[np.ndarray]:
    # Run the model in ONNX Runtime, feeding keyword arguments as named inputs
    options = onnxruntime.SessionOptions()
    options.log_severity_level = 3  # only log errors
    return onnxruntime.InferenceSession(model.SerializeToString(), options).run(
        None,
        {k: np.array(v) for k, v in kwargs.items()}
    )

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.DEBUG)
# Use ORT as the value propagation backend as well
spox._future.set_value_prop_backend(spox._future.ValuePropBackend.ONNXRUNTIME)
[2]:
x = argument(Tensor(float, ()))
Control flow
The ONNX standard supports conditionally evaluated subgraphs (think ‘functions’, subroutines, subprograms…). They aren’t evaluated by the runtime until required by the parent operator.
In Spox, subgraphs have first-class support and are constructed by providing a subgraph callback. The callback is called with special subgraph argument nodes produced by Spox. Any existing Var objects (including those from outer scopes) may be used in a subgraph. The Spox build system places each node in the appropriate (sub)graph of the ONNX output, based on where its outputs are used.
It is recommended to avoid side-effects in subgraph callbacks.
At the moment, value propagation is not run for operators with subgraphs to avoid unexpected build overhead.
Conditional - If
Conditionals are the simplest form of control flow in Spox and can be computed with the If operator (available as if_). Only one of the branches is evaluated at runtime, depending on the value of the passed condition.
[3]:
# Compute relu of a scalar
(relu_x,) = op.if_(
    op.less(x, const(0.)),
    then_branch=lambda: [const(0.)],  # Branches have no arguments
    else_branch=lambda: [x]  # and return an iterable of Vars
)
# relu_x represents the conditional's result (either from the then or else branch)
[4]:
relu_model = build({"x": x}, {"r": relu_x})
[5]:
[float(run(relu_model, x=float(i))[0]) for i in range(-3, 5)]
[5]:
[0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0]
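As a mental model of the runtime behaviour, If can be written in plain Python with zero-argument branch callables, only one of which is invoked. This is a hypothetical sketch (if_reference and relu_reference are illustrative names, not Spox API) and it models what ORT does at inference time. When building the graph, Spox itself calls both branch callbacks to construct both subgraphs, which is one reason to keep callbacks free of side effects.

```python
from typing import Callable, Sequence


def if_reference(
    cond: bool,
    then_branch: Callable[[], Sequence[float]],
    else_branch: Callable[[], Sequence[float]],
) -> Sequence[float]:
    # Runtime semantics only: the branch that is not selected is never invoked.
    return then_branch() if cond else else_branch()


def relu_reference(x: float) -> float:
    # Mirrors the If above: the branches take no arguments and close over x.
    (r,) = if_reference(x < 0.0, lambda: [0.0], lambda: [x])
    return r


print([relu_reference(float(i)) for i in range(-3, 5)])
# [0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0]
```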
If operators can also be nested arbitrarily:
[6]:
# Compute a piecewise constant function at a scalar point
(pc_x,) = op.if_(
    op.less(x, const(0.)),
    then_branch=lambda: op.if_(
        op.less(x, const(-2.)),
        then_branch=lambda: [const(-3)],
        else_branch=lambda: [const(-1)],
    ),
    else_branch=lambda: op.if_(
        op.less(x, const(1.5)),
        then_branch=lambda: [const(2)],
        else_branch=lambda: [const(4)],
    ),
)
[7]:
pc_model = build({"x": x}, {"r": pc_x})
[8]:
[float(run(pc_model, x=float(i))[0]) for i in range(-5, 5)]
[8]:
[-3.0, -3.0, -3.0, -1.0, -1.0, 2.0, 2.0, 4.0, 4.0, 4.0]
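For checking the outputs above, the nested conditionals compute the same thing as the ordinary Python function below (a plain-Python reference with a hypothetical name, not Spox code). The branches return integer constants, which is why the cell above converts each result with float(...).

```python
def pc_reference(x: float) -> int:
    # The outer condition selects which inner conditional is evaluated.
    if x < 0.0:
        return -3 if x < -2.0 else -1
    return 2 if x < 1.5 else 4


print([float(pc_reference(float(i))) for i in range(-5, 5)])
# [-3.0, -3.0, -3.0, -1.0, -1.0, 2.0, 2.0, 4.0, 4.0, 4.0]
```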
Fold - Loop
Loop is an operator implementing a for-loop-like control flow construct. It could also be seen as a hybrid combination of functional programming primitives like take-while, fold, and scan.
This time the subgraph callback takes arguments: the iteration number, the current loop condition, and the accumulators. It should return the updated condition (the loop keeps running while it is true), the updated accumulators, and the scanned values.
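The semantics can be modelled in plain Python. The sketch below assumes the body receives (iteration, condition, *accumulators) and returns [condition, *accumulators, *scans], as described above; loop_reference is a hypothetical name, and it is a simplified model of ONNX Loop, not ONNX or ORT code.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Body = Callable[..., Sequence[object]]


def loop_reference(
    max_trip_count: Optional[int],
    cond: Optional[bool],
    v_initial: Sequence[object],
    body: Body,
) -> Tuple[List[object], List[List[object]]]:
    """Simplified plain-Python model of ONNX Loop (a sketch, not ONNX/ORT code)."""
    n = len(v_initial)  # N: number of accumulators (loop-carried values)
    accumulators = list(v_initial)
    scans: Optional[List[List[object]]] = None  # one list per scan output
    keep_going = True if cond is None else bool(cond)
    i = 0
    while keep_going and (max_trip_count is None or i < max_trip_count):
        outputs = list(body(i, keep_going, *accumulators))
        keep_going = bool(outputs[0])  # the body's condition output
        accumulators = outputs[1 : 1 + n]  # next N values are accumulators
        step_scans = outputs[1 + n :]  # everything else is scanned
        if scans is None:
            scans = [[] for _ in step_scans]
        for column, value in zip(scans, step_scans):
            column.append(value)
        i += 1
    return accumulators, (scans if scans is not None else [])


# 0 + 1 + ... + x for x = 3: x + 1 iterations, one accumulator, no scans.
x = 3
(total,), _ = loop_reference(
    x + 1, None, [0.0], lambda i, _cond, a: [True, float(i) + a]
)
print(total)  # 6.0
```

Note that the condition returned by an iteration is checked before the next one starts, so a body returning False still has its accumulator updates kept.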
A basic example computing 0 + 1 + ... + x:
[9]:
(sum_x,) = op.loop(
    op.add(op.cast(x, to=int), const(1)),  # x+1 iterations
    v_initial=[const(0.)],  # a := 0 at the start
    body=lambda i, _, a: [  # iteration (i), condition (_), accumulator (a)
        const(True),  # continue
        op.add(op.cast(i, to=float), a)  # step is a := float(i) + a
    ]
)
# ONNX drops shape information for accumulators,
# reshape into scalar explicitly
sum_x = scalar(sum_x)
[10]:
sum_model = build({"x": x}, {"r": sum_x})
[11]:
[float(run(sum_model, x=float(i))[0]) for i in range(8)]
[11]:
[0.0, 1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0]
A slightly more complex example, involving a scan, computes the factorials up to x:
[12]:
(fact_x, facts_x) = op.loop(
    op.add(op.cast(x, to=int), const(1)),  # x+1 iterations
    v_initial=[const(1.)],  # a := 1 at the start
    body=lambda i, _, a: [  # iteration (i), condition (_), accumulator (a)
        const(True),  # continue
        op.mul(
            op.add(op.cast(i, to=float), const(1.)),
            a
        ),  # a := (float(i) + 1) * a
        a  # scan the incoming (not yet updated) a
    ]
)
fact_x = scalar(fact_x)
facts_x = op.reshape(facts_x, const([-1]))
[13]:
fact_model = build({"x": x}, {"r": fact_x, "rs": facts_x})
[14]:
[run(fact_model, x=float(i)) for i in range(5)]
[14]:
[[array(1.), array([1.])],
[array(2.), array([1., 1.])],
[array(6.), array([1., 1., 2.])],
[array(24.), array([1., 1., 2., 6.])],
[array(120.), array([ 1., 1., 2., 6., 24.])]]
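The scanned value is the accumulator as it entered each iteration, while the final accumulator has already been multiplied once more. That is why the result is (x + 1)! while the scanned array ends at x!. A plain-Python sketch of the same computation (factorials_reference is a hypothetical name, not Spox code):

```python
from typing import List, Tuple


def factorials_reference(x: int) -> Tuple[float, List[float]]:
    a = 1.0  # accumulator, a := 1 at the start
    scanned: List[float] = []  # one scan output, collected across iterations
    for i in range(x + 1):  # x + 1 iterations
        scanned.append(a)  # scan the incoming a ...
        a = (float(i) + 1.0) * a  # ... and carry the updated one forward
    return a, scanned


for x in range(5):
    print(factorials_reference(x))
# (1.0, [1.0])
# (2.0, [1.0, 1.0])
# (6.0, [1.0, 1.0, 2.0])
# (24.0, [1.0, 1.0, 2.0, 6.0])
# (120.0, [1.0, 1.0, 2.0, 6.0, 24.0])
```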
Keep in mind that a loop may have any number of accumulators and scans. ONNX resolves which value is which by position: the body takes 2 + N arguments (iteration, condition, N accumulators) and returns 1 + N + K results (condition, N accumulators, K scans). In the two examples above we have N = 1, K = 0 and N = 1, K = 1, respectively.
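Because the split is purely positional, it can be expressed as a small plain-Python helper. The names below are hypothetical and only illustrate the counting rule: N is fixed by the number of initial values in v_initial, and every body result after the first 1 + N is treated as a scan output.

```python
from typing import List, Sequence, Tuple


def loop_body_arity(n_accumulators: int, n_scans: int) -> Tuple[int, int]:
    # (number of body arguments, number of body results)
    return 2 + n_accumulators, 1 + n_accumulators + n_scans


def split_loop_body_outputs(
    outputs: Sequence[object], n_accumulators: int
) -> Tuple[object, List[object], List[object]]:
    # Split [cond, *accumulators, *scans] purely by position.
    if len(outputs) < 1 + n_accumulators:
        raise ValueError(
            f"expected at least {1 + n_accumulators} results, got {len(outputs)}"
        )
    cond = outputs[0]
    accumulators = list(outputs[1 : 1 + n_accumulators])
    scans = list(outputs[1 + n_accumulators :])
    return cond, accumulators, scans


# The factorial body returns [cond, a_new, a]: N = 1, K = 1.
print(loop_body_arity(1, 1))  # (3, 3)
print(split_loop_body_outputs(["cond", "a_new", "a"], n_accumulators=1))
# ('cond', ['a_new'], ['a'])
```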
Sequences
Sequences are another type in the ONNX standard. In Spox they are first-class values, on the same footing as tensors.
They may, however, cause type problems, as support for them is limited.
We’ll go through some basic sequence operators and the results that value propagation computes for them:
[15]:
elems = op.sequence_construct([const(i) for i in [1, 2, 3, 4]])
elems
[15]:
<Var from ai.onnx@11::SequenceConstruct->output_sequence of [int64] = [array(1), array(2), array(3), array(4)]>
[16]:
elems.type
[16]:
Sequence(elem_type=Tensor(dtype=int64, shape=()))
[17]:
scalar(op.sequence_at(elems, op.const(2)))
[17]:
<Var from ai.onnx@14::Reshape->reshaped of int64 = 3>
[18]:
op.sequence_insert(elems, const(5))
[18]:
<Var from ai.onnx@11::SequenceInsert->output_sequence of [int64] = [array(1), array(2), array(3), array(4), array(5)]>
[19]:
op.sequence_insert(elems, const(7), const(1))
[19]:
<Var from ai.onnx@11::SequenceInsert->output_sequence of [int64] = [array(1), array(7), array(2), array(3), array(4)]>
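The propagated values above behave like operations on an immutable Python list: SequenceInsert returns a new sequence, which is why the second insert produces four original elements plus 7, without the 5 from the first insert. A plain-Python sketch of these semantics (the function names are hypothetical; handling negative insert positions by adding the length is an assumption based on the ONNX operator description):

```python
from typing import List, Optional


def sequence_at(seq: List[int], position: int) -> int:
    # Valid positions are [-n, n - 1]; negative ones count from the back.
    n = len(seq)
    if not -n <= position < n:
        raise IndexError(position)
    return seq[position]


def sequence_insert(
    seq: List[int], value: int, position: Optional[int] = None
) -> List[int]:
    # Returns a new list and leaves `seq` unchanged.
    n = len(seq)
    if position is None:
        position = n  # default: insert at the back
    if not -n <= position <= n:
        raise IndexError(position)
    if position < 0:
        position += n  # assumed offset for negative positions
    return seq[:position] + [value] + seq[position:]


elems = [1, 2, 3, 4]
print(sequence_at(elems, 2))  # 3
print(sequence_insert(elems, 5))  # [1, 2, 3, 4, 5]
print(sequence_insert(elems, 7, 1))  # [1, 7, 2, 3, 4]
print(elems)  # [1, 2, 3, 4]
```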
Example - dynamic piecewise function
We’ll now go through a longer example combining the above. We’ll store coefficients of linear functions in sequences, along with points defining the pieces of a piecewise linear function. A loop will find the piece corresponding to a query point.
In this example, coefficients, intercepts and pieces are constants; however, they could be computed dynamically.
[20]:
coefficients = op.sequence_construct([const(i) for i in [-1.0, 1.0, -2.0, 0.5]])
intercepts = op.sequence_construct([const(i) for i in [-3.0, 3.0, 3.0, -4.5]])
pieces = op.sequence_construct([const(i) for i in [-3.0, 0.0, 3.0]])
[21]:
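(piece,) = op.loop(
    op.add(op.sequence_length(pieces), const(1)),  # at most len(pieces) + 1 iterations
    v_initial=[const(0)],
    body=lambda i, _, _i: [
        # Continue while i < len(pieces) and x > pieces[i], i.e. stop at the
        # first piece whose right endpoint is >= x (or after the last point)
        *op.if_(
            op.less(i, op.sequence_length(pieces)),
            then_branch=lambda: [op.greater(x, op.sequence_at(pieces, i))],
            else_branch=lambda: [op.const(False)]
        ),
        i  # accumulator := current i (the piece index once the loop stops)
    ]
)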
[22]:
result = op.add(
    op.mul(x, op.sequence_at(coefficients, piece)),
    op.sequence_at(intercepts, piece)
)
[23]:
lp_model = build({"x": x}, {"r": result})
[24]:
[float(run(lp_model, x=float(i))[0]) for i in range(-5, 6)]
[24]:
[2.0, 1.0, 0.0, 1.0, 2.0, 3.0, 1.0, -1.0, -3.0, -2.5, -2.0]
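For reference, the same piecewise linear function in plain Python (hypothetical names, not Spox code). The while loop mirrors the ONNX Loop: it advances while x lies strictly to the right of pieces[i], so a breakpoint belongs to the piece on its left. Since pieces is sorted, the index found this way equals bisect.bisect_left(pieces, x), which is what a binary search would give.

```python
import bisect
from typing import Sequence


def piece_index(x: float, pieces: Sequence[float]) -> int:
    # Linear scan, as in the Loop above.
    i = 0
    while i < len(pieces) and x > pieces[i]:
        i += 1
    return i


def piecewise_linear_reference(
    x: float,
    coefficients: Sequence[float],
    intercepts: Sequence[float],
    pieces: Sequence[float],
) -> float:
    i = piece_index(x, pieces)
    return coefficients[i] * x + intercepts[i]


coefficients = [-1.0, 1.0, -2.0, 0.5]
intercepts = [-3.0, 3.0, 3.0, -4.5]
pieces = [-3.0, 0.0, 3.0]
print([piecewise_linear_reference(float(i), coefficients, intercepts, pieces)
       for i in range(-5, 6)])
# [2.0, 1.0, 0.0, 1.0, 2.0, 3.0, 1.0, -1.0, -3.0, -2.5, -2.0]
print(piece_index(4.0, pieces), bisect.bisect_left(pieces, 4.0))  # 3 3
```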