{ "cells": [ { "cell_type": "markdown", "source": [ "# Advanced features" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "## Introduction\n", "\n", "This guide covers an advanced example of (meta)programming ONNX in Spox. It uses control flow (subgraphs) and an alternative value type (sequences).\n", "\n", "We will use ONNX Runtime (ORT) for these examples, as its implementation is more complete than the ONNX reference implementation." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 1, "outputs": [], "source": [ "import warnings\n", "import logging\n", "import numpy as np\n", "import onnx\n", "import onnxruntime\n", "import spox._future\n", "from spox import argument, build, Tensor, Var\n", "import spox.opset.ai.onnx.v17 as op\n", "\n", "def const(value):\n", "    # Wrap a Python or NumPy value in a Constant node\n", "    return op.constant(value=np.array(value))\n", "\n", "def scalar(var: Var):\n", "    # Reshape to the empty shape (), i.e. a rank-0 tensor\n", "    return op.reshape(var, const(np.array([], int)))\n", "\n", "def run(model: onnx.ModelProto, **kwargs) -> list[np.ndarray]:\n", "    # Run the model in ONNX Runtime; keyword arguments are the model inputs\n", "    options = onnxruntime.SessionOptions()\n", "    options.log_severity_level = 3\n", "    return onnxruntime.InferenceSession(model.SerializeToString(), options).run(\n", "        None,\n", "        {k: np.array(v) for k, v in kwargs.items()}\n", "    )\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "logging.basicConfig(level=logging.DEBUG)\n", "spox._future.set_value_prop_backend(spox._future.ValuePropBackend.ONNXRUNTIME)" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "x = argument(Tensor(float, ()))" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "## Control flow\n", "\n", "The ONNX standard supports conditionally evaluated `subgraphs` (think 'functions', subroutines, subprograms...). They aren't evaluated by the runtime until required by the parent operator.\n", "\n", "In Spox, subgraphs have **first-class support** and are constructed by providing a _subgraph callback_. 
The callback is called with special subgraph argument nodes produced by Spox. Any existing `Var` objects (including from outer scopes) may be used in a subgraph. The Spox build system will appropriately place nodes in the ONNX output based on where their outputs are used.\n", "\n", "It is recommended to avoid side-effects in subgraph callbacks.\n", "\n", "_At the moment, value propagation is not run for operators with subgraphs to avoid unexpected build overhead._" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "### Conditional - If\n", "\n", "Conditionals are the simplest form of control flow in Spox, and may be computed with the `If` operator (available as `if_`). Only one of the branches is evaluated at runtime, depending on the value of the passed condition." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 3, "outputs": [], "source": [ "# Compute relu of a scalar\n", "(relu_x,) = op.if_(\n", "    op.less(x, const(0.)),\n", "    then_branch=lambda: [const(0.)],  # Branches have no arguments\n", "    else_branch=lambda: [x]  # And return an iterable of Vars\n", ")\n", "# relu_x represents the conditional's result (either from the then or else branch)" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 4, "outputs": [], "source": [ "relu_model = build({\"x\": x}, {\"r\": relu_x})" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 5, "outputs": [ { "data": { "text/plain": "[0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0]" }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[float(run(relu_model, x=float(i))[0]) for i in range(-3, 5)]" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "Ifs can also be composed arbitrarily:" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 6, "outputs": [], "source": [ "# Compute a piecewise constant function at a scalar 
point\n", "(pc_x,) = op.if_(\n", "    op.less(x, const(0.)),\n", "    then_branch=lambda: op.if_(\n", "        op.less(x, const(-2.)),\n", "        then_branch=lambda: [const(-3)],\n", "        else_branch=lambda: [const(-1)],\n", "    ),\n", "    else_branch=lambda: op.if_(\n", "        op.less(x, const(1.5)),\n", "        then_branch=lambda: [const(2)],\n", "        else_branch=lambda: [const(4)],\n", "    ),\n", ")" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 7, "outputs": [], "source": [ "pc_model = build({\"x\": x}, {\"r\": pc_x})" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 8, "outputs": [ { "data": { "text/plain": "[-3.0, -3.0, -3.0, -1.0, -1.0, 2.0, 2.0, 4.0, 4.0, 4.0]" }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[float(run(pc_model, x=float(i))[0]) for i in range(-5, 5)]" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "### Fold - Loop" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "`Loop` is an operator implementing a for-loop-like control flow construct. It can also be seen as a combination of functional programming primitives like take-while, fold, and scan.\n", "\n", "This time the subgraph callback takes arguments: the iteration number, the current stop condition, and the accumulators. It should return the updated stop condition (true means the loop continues), the updated accumulators, and the scanned results.\n", "\n", "A basic example computing `0 + 1 + ... 
+ x`:" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 9, "outputs": [], "source": [ "(sum_x,) = op.loop(\n", "    op.add(op.cast(x, to=int), const(1)),  # x+1 iterations\n", "    v_initial=[const(0.)],  # a := 0 at the start\n", "    body=lambda i, _, a: [  # iteration (i), stopping (_), accumulator (a)\n", "        const(True),  # continue\n", "        op.add(op.cast(i, to=float), a)  # step is a := float(i) + a\n", "    ]\n", ")\n", "# ONNX drops shape information for accumulators,\n", "# reshape into scalar explicitly\n", "sum_x = scalar(sum_x)" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 10, "outputs": [], "source": [ "sum_model = build({\"x\": x}, {\"r\": sum_x})" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 11, "outputs": [ { "data": { "text/plain": "[0.0, 1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0]" }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[float(run(sum_model, x=float(i))[0]) for i in range(8)]" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "A slightly more complex example involving a scan - computing factorials up to `x`:" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 12, "outputs": [], "source": [ "(fact_x, facts_x) = op.loop(\n", "    op.add(op.cast(x, to=int), const(1)),  # x+1 iterations\n", "    v_initial=[const(1.)],  # a := 1 at the start\n", "    body=lambda i, _, a: [  # iteration (i), stopping (_), accumulator (a)\n", "        const(True),  # continue\n", "        op.mul(\n", "            op.add(op.cast(i, to=float), const(1.)),\n", "            a\n", "        ),  # a := (float(i) + 1) * a\n", "        a  # scan a (its value before this step's update)\n", "    ]\n", ")\n", "fact_x = scalar(fact_x)  # final accumulator: (x + 1)!\n", "facts_x = op.reshape(facts_x, const([-1]))  # scanned values: 0!, 1!, ..., x!" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 13, "outputs": [], "source": [ "fact_model = build({\"x\": x}, {\"r\": fact_x, \"rs\": facts_x})" ], "metadata": { "collapsed": false } 
}, { "cell_type": "code", "execution_count": 14, "outputs": [ { "data": { "text/plain": "[[array(1.), array([1.])],\n [array(2.), array([1., 1.])],\n [array(6.), array([1., 1., 2.])],\n [array(24.), array([1., 1., 2., 6.])],\n [array(120.), array([ 1., 1., 2., 6., 24.])]]" }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[run(fact_model, x=float(i)) for i in range(5)]" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "Keep in mind that more accumulators and scans can be present. ONNX resolves which value is what by counting positions - there are `2 + N` arguments (iteration, stop, `N` accumulators) and `1 + N + K` results (stop, `N` accumulators, `K` scans). In the above two examples we have `N = 1, K = 0` and `N = 1, K = 1` respectively." ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "## Sequences" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "Sequences are another type in the ONNX standard. 
In Spox they are treated on the same level as the more common tensor type.\n", "\n", "They may, however, cause type problems due to their limited support.\n", "\n", "We'll go through some basic sequence operators and inspect their behaviour using value propagation:" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 15, "outputs": [ { "data": { "text/plain": "output_sequence of [int64] = [array(1), array(2), array(3), array(4)]>" }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "elems = op.sequence_construct([const(i) for i in [1, 2, 3, 4]])\n", "elems" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 16, "outputs": [ { "data": { "text/plain": "Sequence(elem_type=Tensor(dtype=int64, shape=())" }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "elems.type" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 17, "outputs": [ { "data": { "text/plain": "reshaped of int64 = 3>" }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scalar(op.sequence_at(elems, const(2)))" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 18, "outputs": [ { "data": { "text/plain": "output_sequence of [int64] = [array(1), array(2), array(3), array(4), array(5)]>" }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "op.sequence_insert(elems, const(5))" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 19, "outputs": [ { "data": { "text/plain": "output_sequence of [int64] = [array(1), array(7), array(2), array(3), array(4)]>" }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "op.sequence_insert(elems, const(7), const(1))" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "## Example - dynamic piecewise function\n", 
"\n", "We'll now go through a longer example combining the above. We'll store coefficients of linear functions in sequences, along with points defining the pieces of a piecewise linear function. A loop will find the piece corresponding to a query point.\n", "\n", "In this example, `coefficients`, `intercepts` and `pieces` are constants; however, they could be computed dynamically." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 20, "outputs": [], "source": [ "coefficients = op.sequence_construct([const(i) for i in [-1.0, 1.0, -2.0, 0.5]])\n", "intercepts = op.sequence_construct([const(i) for i in [-3.0, 3.0, 3.0, -4.5]])\n", "pieces = op.sequence_construct([const(i) for i in [-3.0, 0.0, 3.0]])" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 21, "outputs": [], "source": [ "(piece,) = op.loop(\n", "    op.add(op.sequence_length(pieces), const(1)),\n", "    v_initial=[const(0)],\n", "    body=lambda i, _, _i: [\n", "        # Continue while i < len(pieces) and x > pieces[i]\n", "        *op.if_(\n", "            op.less(i, op.sequence_length(pieces)),\n", "            then_branch=lambda: [op.greater(x, op.sequence_at(pieces, i))],\n", "            else_branch=lambda: [const(False)]\n", "        ),\n", "        i  # carry i as the accumulator (index of the piece)\n", "    ]\n", ")" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 22, "outputs": [], "source": [ "result = op.add(\n", "    op.mul(x, op.sequence_at(coefficients, piece)),\n", "    op.sequence_at(intercepts, piece)\n", ")" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 23, "outputs": [], "source": [ "lp_model = build({\"x\": x}, {\"r\": result})" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 24, "outputs": [ { "data": { "text/plain": "[2.0, 1.0, 0.0, 1.0, 2.0, 3.0, 1.0, -1.0, -3.0, -2.5, -2.0]" }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[float(run(lp_model, x=float(i))[0]) for i in range(-5, 6)]" ], 
"metadata": { "collapsed": false } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }