Circuit Evaluation Tutorial

Backends

Once we have created a circuit, we can start using it. KLay relies on a backend to perform inference. Currently, PyTorch and Jax backends are implemented.

We can turn the circuit into a PyTorch module as follows.

module = circuit.to_torch_module(semiring="log")

As the circuit is now a regular PyTorch module, we can move it to a different device.

module = module.to("cuda:0")

We can turn the circuit into a Jax function as follows.

func = circuit.to_jax_function(semiring="log")

The choice of semiring determines which operations the sum and product nodes perform. The default is the log semiring, in which sum nodes compute the logsumexp operation and product nodes compute addition. In the real semiring, sum and product nodes compute ordinary sums and products.
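As a rough illustration of the two semirings, the plain-Python sketch below (it does not call KLay) evaluates a small hypothetical circuit, (a AND b) OR (NOT a AND c), once with real-semiring operations on probabilities and once with log-semiring operations on log-probabilities. The helper names are made up for this example. Because logsumexp and addition are the log-space counterparts of sum and product, exponentiating the log-semiring result should recover the real-semiring result.

```python
import math


# Real semiring: sum nodes add their children, product nodes multiply them.
def real_sum(*xs):
    return sum(xs)


def real_prod(*xs):
    return math.prod(xs)


# Log semiring: sum nodes compute logsumexp, product nodes add.
def log_sum(*xs):
    m = max(xs)
    if m == float("-inf"):  # every child has probability zero
        return m
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_prod(*xs):
    return sum(xs)


def toy_circuit(w_a, w_not_a, w_b, w_c, plus, times):
    # (a AND b) OR (NOT a AND c), evaluated with the given semiring operations
    return plus(times(w_a, w_b), times(w_not_a, w_c))


p_a, p_b, p_c = 0.3, 0.9, 0.5
real_out = toy_circuit(p_a, 1 - p_a, p_b, p_c, real_sum, real_prod)
log_out = toy_circuit(
    math.log(p_a), math.log(1 - p_a), math.log(p_b), math.log(p_c),
    log_sum, log_prod,
)
print(real_out)           # approximately 0.62 (0.3 * 0.9 + 0.7 * 0.5)
print(math.exp(log_out))  # the same value, up to floating-point rounding
```

Writing the circuit once and passing in the semiring operations mirrors the point above: the semiring changes what each node computes, not the structure of the circuit.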

KLay does not add a batch dimension by default, so use vmap to perform batched inference.

module = torch.vmap(module)
func = jax.vmap(func)

For the best runtime performance, use JIT compilation.

module = torch.compile(module, mode="reduce-overhead")
func = jax.jit(func)

KLay also supports probabilistic circuits, which have weights associated with the edges of sum nodes.

module2 = circuit.to_torch_module(semiring="real", probabilistic=True)
# Warning: not yet implemented for the Jax backend!
func2 = circuit.to_jax_function(semiring="real", probabilistic=True)
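To illustrate what a sum node in a probabilistic circuit computes, here is a minimal plain-Python sketch of a single weighted sum node. It is not KLay's implementation, and the weights and child values are hypothetical. In the real semiring the node returns sum_i w_i * c_i; in the log semiring the log-weights are added to the children's log-values before the logsumexp, which yields the logarithm of the same quantity.

```python
import math


def logsumexp(xs):
    m = max(xs)
    if m == float("-inf"):  # all terms have probability zero
        return m
    return m + math.log(sum(math.exp(x - m) for x in xs))


def weighted_sum_real(weights, children):
    # Real semiring: sum_i w_i * c_i
    return sum(w * c for w, c in zip(weights, children))


def weighted_sum_log(log_weights, log_children):
    # Log semiring: logsumexp_i (log w_i + log c_i)
    return logsumexp([lw + lc for lw, lc in zip(log_weights, log_children)])


weights = [0.2, 0.5, 0.3]   # hypothetical edge weights of one sum node
children = [0.9, 0.4, 0.1]  # hypothetical child values (probabilities)

real_value = weighted_sum_real(weights, children)
log_value = weighted_sum_log(
    [math.log(w) for w in weights],
    [math.log(c) for c in children],
)
print(real_value)           # approximately 0.41
print(math.exp(log_value))  # matches real_value up to rounding
```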

Inference

The input to the circuit should be a tensor whose size equals the number of input literals. When using the log semiring, the inputs are log-probabilities, while in the real or mpe semiring they should be probabilities. If you are using a probabilistic circuit, you will typically have input distributions that produce these (log-)probabilities before the circuit.

# PyTorch
inputs = torch.tensor([...])
outputs = module(inputs)
# Jax
inputs = jnp.array([...])
outputs = func(inputs)
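If the input probabilities come from a model that outputs one logit per literal (a hypothetical setup, not something KLay prescribes), a sigmoid yields real-semiring inputs and a log-sigmoid yields log-semiring inputs. The plain-Python sketch below computes the log-sigmoid directly as -softplus(-z), which stays finite for very negative logits, where taking the log of an underflowed sigmoid would fail. In PyTorch or Jax you would use the frameworks' own vectorized sigmoid and log-sigmoid functions instead.

```python
import math


def sigmoid(z):
    # Probability of a literal, usable as a real-semiring input
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def log_sigmoid(z):
    # log(sigmoid(z)) = -softplus(-z), computed without overflow;
    # usable as a log-semiring input
    return -(max(-z, 0.0) + math.log1p(math.exp(-abs(z))))


logits = [2.0, -1.0, 0.0, -800.0]             # hypothetical outputs, one per literal
probs = [sigmoid(z) for z in logits]          # inputs for the real semiring
log_probs = [log_sigmoid(z) for z in logits]  # inputs for the log semiring
print(probs)
print(log_probs)  # log_sigmoid(-800.0) is -800.0, while math.log(sigmoid(-800.0))
                  # raises ValueError because the sigmoid underflows to 0.0
```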

Gradients are computed in the usual fashion.

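# PyTorch: inputs need requires_grad=True, and backward() needs a scalar
inputs = torch.tensor([...], requires_grad=True)
outputs = module(inputs)
outputs.sum().backward()
# Jax: jax.grad likewise needs a scalar-valued function
inputs = jnp.array([...])
grad_func = jax.jit(jax.grad(lambda x: func(x).sum()))
grad_func(inputs)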
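One property that helps when interpreting these gradients: in the log semiring, the derivative of a sum node (logsumexp) with respect to each child is the softmax of the children, while a product node (addition) passes its incoming gradient unchanged to every child. The plain-Python sketch below checks the sum-node case against a central finite difference. It illustrates the math only, not KLay internals, and the child values are hypothetical.

```python
import math


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


children = [-0.5, -1.2, -3.0]  # hypothetical log-values of a sum node's children
analytic = softmax(children)   # d logsumexp / d child_i

eps = 1e-6
numeric = []
for i in range(len(children)):
    up = children.copy()
    up[i] += eps
    down = children.copy()
    down[i] -= eps
    # Central finite difference for the i-th partial derivative
    numeric.append((logsumexp(up) - logsumexp(down)) / (2 * eps))

print(analytic)
print(numeric)  # agrees closely with analytic
```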

The inputs tensor must contain a weight for each positive literal. By default, the weights of the negative literals are derived from these: in the real semiring, for example, if x is the weight of literal l, then 1 - x is the weight of its negation -l. To use other weights, pass a second tensor containing a weight for each negative literal.

# PyTorch
inputs = torch.tensor([...])
neg_inputs = torch.tensor([...])  # if omitted, derived from inputs (1 - inputs in the real semiring)
outputs = module(inputs, neg_inputs)
# Jax
inputs = jnp.array([...])
neg_inputs = jnp.array([...])  # if omitted, derived from inputs (1 - inputs in the real semiring)
outputs = func(inputs, neg_inputs)
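The sketch below spells out default negative-literal weights in plain Python. The real-semiring rule 1 - x comes from the text above. For the log semiring we assume the corresponding log-space complement log(1 - exp(x)); this tutorial does not state that rule, so treat that part as an assumption. It is computed here with a standard numerically stable formulation (often called log1mexp) that avoids cancellation when x is close to 0. Values like these could be passed explicitly as neg_inputs, or replaced with your own weights.

```python
import math


def neg_weight_real(x):
    # Real semiring rule from the text: the weight of -l is 1 - (weight of l)
    return 1.0 - x


def neg_weight_log(x):
    # ASSUMPTION: log-semiring analogue log(1 - exp(x)) for a log-probability x <= 0,
    # using the usual stable split at -log(2)
    if x > 0:
        raise ValueError("log-probability must be <= 0")
    if x == 0:
        return float("-inf")  # probability 1, so the negation has probability 0
    if x > -math.log(2):
        return math.log(-math.expm1(x))
    return math.log1p(-math.exp(x))


probs = [0.25, 0.999999, 1.0]  # hypothetical literal probabilities
neg_probs = [neg_weight_real(p) for p in probs]
log_probs = [math.log(p) for p in probs]
neg_log_probs = [neg_weight_log(lp) for lp in log_probs]
print(neg_probs)
print([math.exp(v) for v in neg_log_probs])  # close to neg_probs
```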