Converters

Spox does not itself provide any ONNX converters (utilities for translating ML models into ONNX), but it can readily be used to implement a converter protocol. This section walks through one way of doing so. In general, operations from libraries like numpy or from deep learning frameworks are the easiest to convert, since ONNX follows similar principles.

[1]:
from typing import Dict, List
import onnx
import onnxruntime
import numpy as np
from spox import argument, build, Tensor, Var
import spox.opset.ai.onnx.v17 as op


def run(model: onnx.ModelProto, **kwargs) -> List[np.ndarray]:
    # Run the model in onnxruntime; keyword arguments are the model inputs.
    return onnxruntime.InferenceSession(model.SerializeToString()).run(
        None,
        {k: np.array(v) for k, v in kwargs.items()}
    )

Functions

We’ll start by converting simple Python functions that operate on numpy.ndarrays into Spox equivalents that operate on Vars (holding tensors).

Let’s define functions computing means on a pair of tensors:

[2]:
def arithmetic_mean(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return np.divide(np.add(a, b), 2)


def geometric_mean(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return np.sqrt(np.multiply(a, b))


def harmonic_mean(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return np.divide(2., np.add(
        np.reciprocal(a),
        np.reciprocal(b)
    ))

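We can now define equivalents in Spox. We’ll follow a contract under which every numpy.ndarray argument or result becomes a Var that is expected to hold a tensor. Note that op.constant(value_float=2.) produces a float32 scalar, so these versions assume float32 inputs.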

[3]:
def spox_arithmetic_mean(a: Var, b: Var) -> Var:
    return op.div(op.add(a, b), op.constant(value_float=2.))


def spox_geometric_mean(a: Var, b: Var) -> Var:
    return op.sqrt(op.mul(a, b))


def spox_harmonic_mean(a: Var, b: Var) -> Var:
    return op.div(op.constant(value_float=2.), op.add(
        op.reciprocal(a),
        op.reciprocal(b)
    ))

Estimators

Let’s also consider an sklearn-like estimator on ‘dataframes’ (dictionaries of arrays).

[4]:
class PairwiseMeans:
    kind: str  # 'arithmetic', 'geometric', or 'harmonic'
    first: str
    second: str  # name of first and second 'column' to find the mean of

    def __init__(self, kind: str, first: str, second: str):
        self.kind = kind
        self.first = first
        self.second = second

    def predict(self, data: Dict[str, np.ndarray]) -> np.ndarray:
        means = {
            'arithmetic': arithmetic_mean,
            'geometric': geometric_mean,
            'harmonic': harmonic_mean,
        }
        return means[self.kind](data[self.first], data[self.second])

The equivalent in Spox could be a class ‘decorating’ a PairwiseMeans instance: it consumes the estimator and implements the same interface, but operates on Vars instead of numpy.ndarrays.

[5]:
class SpoxPairwiseMeans:
    estimator: PairwiseMeans

    def __init__(self, estimator: PairwiseMeans):
        self.estimator = estimator

    def predict(self, data: Dict[str, Var]) -> Var:
        means = {
            'arithmetic': spox_arithmetic_mean,
            'geometric': spox_geometric_mean,
            'harmonic': spox_harmonic_mean,
        }
        return means[self.estimator.kind](data[self.estimator.first], data[self.estimator.second])

Converter

To provide a simple API for conversion, we can define a convert function that dispatches on the object being converted. The mapping could also be stored in, e.g., a dictionary, which would make it extensible at runtime.

[6]:
def convert(obj):
    if obj is arithmetic_mean:
        return spox_arithmetic_mean
    elif obj is geometric_mean:
        return spox_geometric_mean
    elif obj is harmonic_mean:
        return spox_harmonic_mean
    elif type(obj) is PairwiseMeans:
        return SpoxPairwiseMeans(obj)
    raise ValueError(f"No converter for: {obj}")
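
The dictionary-based mapping mentioned above might look like the following minimal sketch. The names FUNCTION_CONVERTERS, TYPE_CONVERTERS, register_function, register_type and convert_with_registry are hypothetical and not part of Spox. The small stand-in functions and classes only take the place of the numpy and Spox implementations defined earlier, so that the block runs on its own.

```python
from typing import Any, Callable, Dict

# Hypothetical registries (not part of Spox): one for plain functions,
# one for estimator classes.
FUNCTION_CONVERTERS: Dict[Callable, Callable] = {}
TYPE_CONVERTERS: Dict[type, Callable[[Any], Any]] = {}


def register_function(source: Callable, target: Callable) -> None:
    """Declare that `target` is the converted equivalent of `source`."""
    FUNCTION_CONVERTERS[source] = target


def register_type(cls: type, factory: Callable[[Any], Any]) -> None:
    """Declare that instances of exactly `cls` are converted by `factory`."""
    TYPE_CONVERTERS[cls] = factory


def convert_with_registry(obj: Any) -> Any:
    # Estimators: dispatch on the exact type, like `type(obj) is ...` above.
    factory = TYPE_CONVERTERS.get(type(obj))
    if factory is not None:
        return factory(obj)
    # Functions: look the object itself up; unhashable objects have no entry.
    try:
        return FUNCTION_CONVERTERS[obj]
    except (KeyError, TypeError):
        raise ValueError(f"No converter for: {obj!r}") from None


# Stand-ins so the sketch is self-contained.
def source_mean(a, b):
    return (a + b) / 2


def converted_mean(a, b):
    return ("converted", a, b)


class SourceEstimator:
    pass


class ConvertedEstimator:
    def __init__(self, estimator: SourceEstimator):
        self.estimator = estimator


register_function(source_mean, converted_mean)
register_type(SourceEstimator, ConvertedEstimator)
```

In this notebook, the registrations would instead pair arithmetic_mean with spox_arithmetic_mean (and likewise for the other means) and PairwiseMeans with SpoxPairwiseMeans. New conversions can then be added without editing the dispatch function.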

To build a model, we have to construct the arguments and pass them, together with the results, to spox.build. This could be abstracted away by using inspect.signature and by extracting the input types from example input data, but here we construct the arguments by hand.

[7]:
pairwise_means = PairwiseMeans('harmonic', 'x', 'z')
[8]:
vec = Tensor(np.float32, ('N',))
x, y, z = argument(vec), argument(vec), argument(vec)


def simple_convert_build(fun):
    return build({'x': x, 'y': y}, {'r': convert(fun)(x, y)})


arithmetic_mean_model = simple_convert_build(arithmetic_mean)
geometric_mean_model = simple_convert_build(geometric_mean)
harmonic_mean_model = simple_convert_build(harmonic_mean)
pairwise_means_model = build(
    {'x': x, 'y': y, 'z': z},
    {'r': convert(pairwise_means).predict({'x': x, 'y': y, 'z': z})}
)
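
As a rough illustration of the inspect.signature idea mentioned above, the sketch below derives an input description for each parameter of a function from example data. The helper infer_input_specs is hypothetical, and treating the leading axis as a dynamic dimension named 'N' is an assumption made for this example. arithmetic_mean is repeated only so that the block runs on its own.

```python
import inspect
from typing import Callable, Dict, Tuple, Union

import numpy as np

Shape = Tuple[Union[int, str], ...]


def infer_input_specs(
    fun: Callable, examples: Dict[str, np.ndarray]
) -> Dict[str, Tuple[type, Shape]]:
    """Map each parameter name of `fun` to a (dtype, shape) pair.

    Hypothetical helper. Assumption: the leading axis of every
    non-scalar example is dynamic and is named 'N'.
    """
    specs: Dict[str, Tuple[type, Shape]] = {}
    for name in inspect.signature(fun).parameters:
        example = np.asarray(examples[name])
        shape: Shape = ("N",) + example.shape[1:] if example.ndim else ()
        specs[name] = (example.dtype.type, shape)
    return specs


def arithmetic_mean(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return np.divide(np.add(a, b), 2)


specs = infer_input_specs(
    arithmetic_mean,
    {
        "a": np.array([1, 2, 3], dtype=np.float32),
        "b": np.array([[4.0, 6.0], [5.0, 7.0]], dtype=np.float32),
    },
)
print(specs)
```

Each (dtype, shape) pair could then be passed to Tensor and argument in the same way as vec in the cell above, giving one argument per parameter name.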

Checking equivalence

We can now test equivalence by running the models in onnxruntime with the run utility defined earlier.

[9]:
x0 = np.array([1, 2, 3], dtype=np.float32)
y0 = np.array([4, 6, 5], dtype=np.float32)
z0 = np.array([-2, -1, -0.5], dtype=np.float32)

An example run looks like this. Note that Spox is no longer involved here: at this point arithmetic_mean_model is a plain onnx.ModelProto.

[10]:
arithmetic_mean(x0, y0), run(arithmetic_mean_model, x=x0, y=y0)[0]
[10]:
(array([2.5, 4. , 4. ], dtype=float32), array([2.5, 4. , 4. ], dtype=float32))
[11]:
tests = [
    (arithmetic_mean, arithmetic_mean_model),
    (geometric_mean, geometric_mean_model),
    (harmonic_mean, harmonic_mean_model),
]
for py_function, onnx_model in tests:
    actual = run(onnx_model, x=x0, y=y0)[0]
    desired = py_function(x0, y0)
    print(actual, desired)
    np.testing.assert_allclose(actual, desired)
[2.5 4.  4. ] [2.5 4.  4. ]
[2.        3.4641016 3.8729835] [2.        3.4641016 3.8729835]
[1.6       3.        3.7499998] [1.6       3.        3.7499998]
[12]:
actual = run(pairwise_means_model, x=x0, y=y0, z=z0)[0]
desired = pairwise_means.predict({'x': x0, 'y': y0, 'z': z0})
print(actual, desired)
np.testing.assert_allclose(actual, desired)
[ 4.  -4.  -1.2] [ 4.  -4.  -1.2]
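
The checks above use a single hand-picked input. A slightly broader check, sketched below under stated assumptions, compares a reference callable against a candidate on several random float32 inputs. The helper check_equivalent, its parameters and the choice of positive inputs in [0.5, 10) are hypothetical. The two lambdas are numpy stand-ins so that the block runs without a model.

```python
from typing import Callable, Dict, Tuple

import numpy as np


def check_equivalent(
    reference: Callable[..., np.ndarray],
    candidate: Callable[..., np.ndarray],
    input_shapes: Dict[str, Tuple[int, ...]],
    rtol: float = 1e-6,
    trials: int = 5,
    seed: int = 0,
) -> None:
    """Assert that both callables agree on random float32 inputs."""
    rng = np.random.default_rng(seed)
    for _ in range(trials):
        # Positive values keep sqrt and reciprocal well-defined.
        inputs = {
            name: rng.uniform(0.5, 10.0, size=shape).astype(np.float32)
            for name, shape in input_shapes.items()
        }
        np.testing.assert_allclose(
            candidate(**inputs), reference(**inputs), rtol=rtol
        )


# numpy stand-ins for a Python function and its converted counterpart.
reference = lambda a, b: np.divide(np.add(a, b), 2)
candidate = lambda a, b: np.multiply(np.add(a, b), np.float32(0.5))

check_equivalent(reference, candidate, {"a": (8,), "b": (8,)})
```

In this notebook, the candidate would instead wrap run with one of the built models, mapping the function's argument names (a, b) to the model's input names (x, y).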