{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Inference" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Spox attempts to perform inference on operators immediately as they are constructed in Python.\n", "This includes two main mechanisms: type (and shape) inference, and value propagation.\n", "\n", "Both are done on a best-effort basis and rely primarily on the ONNX implementations.\n", "Some type inference and value propagation routines may be _patched_ in the generated opset. A patch is a Python implementation within Spox that attempts to follow the standard, but it may be imperfect and have bugs (as can the standard ONNX implementations).\n", "\n", "For inference to work throughout a graph, Spox expects that type information is carried in `Var`s through the entire graph as it is constructed. This enables raising Python exceptions as early as possible when type inference fails, makes debugging easier, and reduces the need to specify redundant type information in Python.\n", "\n", "The general mechanism is the following: a single standard node is built into a _singleton_ (single-node) model as an `onnx.ModelProto`. This model is then passed into `onnx` routines. Afterwards, the inferred information is extracted and converted back into Spox." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "tags": [] }, "outputs": [], "source": [ "import warnings\n", "import numpy as np\n", "import spox\n", "import spox.opset.ai.onnx.v17 as op" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Type inference" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Type and shape inference is run via `onnx.shape_inference.infer_shapes` on the singleton model. Types are converted to and from the ONNX representation internally. 
Some operators may have missing or incomplete type inference implementations (especially ML operators in `ai.onnx.ml`); these may have a patch implemented in Spox.\n", "\n", "A `Var`'s type can be accessed with the `Var.type: Optional[Type]` attribute. It is, however, recommended to use the checked equivalents: `Var.unwrap_type() -> Type` or `Var.unwrap_tensor() -> Tensor`, which work better with type checkers.\n", "\n", "Patches can currently be found as an `infer_output_types` implementation in the respective `Node` class." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "tags": [] }, "outputs": [], "source": [ "x = spox.argument(spox.Tensor(float, ('N',)))\n", "y = spox.argument(spox.Tensor(float, ()))\n", "z = spox.argument(spox.Tensor(int, ('N', 'M')))\n", "w = spox.argument(spox.Tensor(int, (None,)))" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "tags": [] }, "outputs": [ { "data": { "text/plain": "C of float64[N]>" }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Broadcasting of (N) and () into (N)\n", "op.add(x, y)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "tags": [] }, "outputs": [ { "data": { "text/plain": "output of str[N][M]>" }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Casting element type with a Cast\n", "op.cast(z, to=str)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "tags": [] }, "outputs": [ { "data": { "text/plain": "reshaped of int64[?]>" }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Reshape of a matrix into a vector\n", "op.reshape(z, op.constant(value_ints=[-1]))" ] }, { "cell_type": "code", 
"execution_count": 6, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "tags": [] }, "outputs": [ { "data": { "text/plain": "C of float64[N][N]>" }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Using a broadcast of (1, N) and (N, 1) into (N, N)\n", "op.add(\n", " op.unsqueeze(x, op.constant(value_ints=[0])),\n", " op.unsqueeze(x, op.constant(value_ints=[1]))\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "InferenceError: [ShapeInferenceError] (op_type:Add, node name: _this_): B has inconsistent type tensor(int64) -- for Add: inputs [A: float64[N], B: int64[N][M]]\n" ] } ], "source": [ "# Failing shape inference raises an exception\n", "try:\n", " print(op.add(x, z)) # mismatched types: float64, int64\n", "except Exception as e:\n", " print(f\"{type(e).__name__}: {e}\")\n", "else:\n", " assert False" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 8, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "InferenceWarning: Output type for variable reshaped of ai.onnx@14::Reshape was not concrete - ValueError: Tensor float64[...] 
does not specify the shape -- in ?.\n" ] } ], "source": [ "# Missing or unresolvable shape inference emits a warning\n", "# w is a dynamic-size vector => can't determine reshaped rank\n", "warnings.filterwarnings(\"error\")\n", "try:\n", " print(op.reshape(x, w))\n", "except Exception as e:\n", " print(f\"{type(e).__name__}: {e}\")\n", "else:\n", " assert False\n", "warnings.resetwarnings()" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 9, "outputs": [ { "data": { "text/plain": "Tensor(dtype=float64, shape=('N',))" }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Access the type directly (might be None if type inference failed)\n", "op.add(x, y).type" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 10, "outputs": [ { "data": { "text/plain": "Tensor(dtype=float64, shape=('N',))" }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Access the type, asserting it must be a tensor\n", "op.add(x, y).unwrap_tensor()" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "metadata": {}, "source": [ "Spox does not have a facility for extra type information in ``Var`` type hints. The ONNX type system, and particularly tensor shapes, cannot be expressed well in Python type hints. This may be reconsidered in the future if libraries like ``numpy`` start supporting similar features and the Python/ONNX ecosystem develops." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Value propagation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Value propagation in Spox is run via the `onnx.reference` module (added in ONNX 1.13), the reference runtime implementation in Python. It replicates the _partial data propagation_ mechanism of type inference in ONNX, which is essentially constant folding.\n", "\n", "In Spox, a ``Var`` may have a constant value associated with it. 
If all input variables of a standard operator have a value, propagation will be attempted by running the singleton model through the reference implementation.\n", "\n", "The most common instance of value propagation is in the ``Reshape`` operator, where a constant target shape allows determining the resulting shape. If the target shape were not known, even the rank of the output shape could not be determined.\n", "\n", "Value propagation can also be thought of as an **eager execution** mode within Spox, and is well-suited for experimenting with (standard) operators.\n", "\n", "Currently, there isn't a standard way of accessing the propagated value, but it is shown when a ``Var`` is printed.\n", "Value propagation isn't usually patched, as in most cases it is not critical to type inference. Where a patch exists, it is implemented by overriding the `propagate_values` method of the `Node` class." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "tags": [] }, "outputs": [ { "data": { "text/plain": "reshaped of float64[1][2][3]>" }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Trivial reshape example\n", "op.reshape(x, op.constant(value_ints=[1, 2, 3]))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "tags": [] }, "outputs": [ { "data": { "text/plain": "reshaped of float64[3][5]>" }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "s = op.add(\n", " op.mul(op.constant(value_ints=[1, 2]), op.constant(value_int=2)),\n", " op.constant(value_int=1)\n", ") # [1, 2] * 2 + 1 = [3, 5]\n", "# Reshape with a basic constant fold\n", "op.reshape(x, s)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Constant variable values can also be seen in the string representation.\n", "Currently, there isn't a stable way of accessing them programmatically - the internal field is 
`_value` and it can be converted to an ORT-like format with `_get_value()`. The representation isn't currently publicly specified." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "tags": [] }, "outputs": [], "source": [ "def const(xs):\n", " return op.constant(value=np.array(xs))" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "tags": [] }, "outputs": [ { "data": { "text/plain": "C of int64[3] = [2 3 4]>" }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Trivial add\n", "op.add(\n", " const(1),\n", " const([1, 2, 3])\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "tags": [] }, "outputs": [ { "data": { "text/plain": "reshaped of float64[2][2] = [[1. 2.]\n [3. 4.]]>" }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Reshape\n", "mat = op.reshape(\n", " const([1., 2., 3., 4.]),\n", " const([2, 2])\n", ")\n", "mat" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "tags": [] }, "outputs": [ { "data": { "text/plain": "Y of float64[2][2] = [[ 7. 10.]\n [15. 22.]]>" }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Composing value propagation\n", "op.matmul(mat, mat)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "tags": [] }, "outputs": [ { "data": { "text/plain": "C of int64[3] = [2 3 4]>" }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Unstable! 
Programmatic access\n", "v = op.add(\n", " const(1),\n", " const([1, 2, 3])\n", ")\n", "np.testing.assert_allclose(v._get_value(), np.array([2, 3, 4]))\n", "v" ] }, { "cell_type": "markdown", "source": [ "### Testing an ML operator" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "ML operators (`ai.onnx.ml`) can be used similarly to standard ones (`ai.onnx`). Spox ships with pre-generated ML domain opset modules which you can find under `spox.opset.ai.onnx.ml.v3`.\n", "\n", "With value propagation and a supporting backend, you can test an ML operator and run it on input without having to leave Spox:" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 18, "outputs": [], "source": [ "import spox.opset.ai.onnx.ml.v3 as ml\n", "import spox._future\n", "# Currently, you need ORT to run ML operators\n", "spox._future.set_value_prop_backend(spox._future.ValuePropBackend.ONNXRUNTIME)" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 19, "outputs": [ { "data": { "text/plain": "Y of float32[3][1] = [[ 4.]\n [ 7.]\n [10.]]>" }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ml.linear_regressor(\n", " const(np.array([[1], [2], [3]], dtype=np.float32)),\n", " coefficients=[3],\n", " intercepts=[1]\n", ")" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 20, "outputs": [ { "data": { "text/plain": "Y of str[3] = ['?' 
'one' 'two']>" }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ml.label_encoder(\n", " const([0, 1, 2]),\n", " keys_int64s=[1, 2],\n", " values_strings=[\"one\", \"two\"],\n", " default_string=\"?\"\n", ")" ], "metadata": { "collapsed": false } } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 4 }