pyjuice.queries.query

pyjuice.queries.query(pc: TensorCircuit, inputs: Tensor, run_backward: bool = True, fw_input_fn: str | Callable | None = None, bk_input_fn: str | Callable | None = None, fw_output_fn: Callable | None = None, **kwargs)

A general-purpose entry point for running queries on a probabilistic circuit (PC). It runs a forward pass of the PC, optionally followed by a backward pass, and lets the caller substitute custom functions for the input layers. It is the common backend of marginal(), conditional(), and sample().

Parameters:
  • pc (TensorCircuit) – the input PC

  • inputs (torch.Tensor) – input tensor of shape [B, num_vars], where B is the batch size; a custom shape may be used together with a matching fw_input_fn

  • run_backward (bool) – whether to run the backward pass after the forward pass

  • fw_input_fn (Optional[Union[str, Callable]]) – an optional custom function (or the name of an input-layer method) for the forward pass of input layers

  • bk_input_fn (Optional[Union[str, Callable]]) – an optional custom function (or the name of an input-layer method) for the backward pass of input layers

  • fw_output_fn (Optional[Callable]) – an optional function applied to pc right after the forward pass; if provided, its return value is returned immediately and no backward pass is run
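fw_input_fn and bk_input_fn accept either a string or a callable. The sketch below illustrates one plausible way such an argument could be dispatched: None selects the default method, a string is looked up as a method name on each input layer, and a callable is invoked with the layer and the inputs. This is an illustration under assumptions, not PyJuice's implementation: ToyInputLayer, its mars_only method, and run_input_layers are hypothetical names, and the real input layers take different arguments (for example, buffers for node marginals).

```python
from typing import Callable, Optional, Union


class ToyInputLayer:
    """Hypothetical stand-in for an input layer (not a PyJuice class)."""

    def forward(self, inputs):
        # Default behaviour used when no custom input function is given.
        return [("default", x) for x in inputs]

    def mars_only(self, inputs):
        # Hypothetical alternative method, selectable by its name.
        return [("mars_only", x) for x in inputs]


def run_input_layers(layers, inputs, input_fn: Optional[Union[str, Callable]] = None):
    """Evaluate every input layer, dispatching on the type of input_fn (sketch)."""
    outputs = []
    for layer in layers:
        if input_fn is None:
            # No override: call the layer's default method.
            outputs.append(layer.forward(inputs))
        elif isinstance(input_fn, str):
            # A string names a method that every input layer must provide.
            if not hasattr(layer, input_fn):
                raise AttributeError(f"{type(layer).__name__} has no method {input_fn!r}")
            outputs.append(getattr(layer, input_fn)(inputs))
        elif callable(input_fn):
            # A callable receives the layer and the inputs explicitly.
            outputs.append(input_fn(layer, inputs))
        else:
            raise TypeError("input_fn must be None, a str, or a callable")
    return outputs
```

For example, run_input_layers([ToyInputLayer()], [3], lambda layer, xs: [x * 10 for x in xs]) bypasses the layer's own methods entirely, while passing "mars_only" selects that method by name on every layer.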

Returns:

if fw_output_fn is provided, its return value; otherwise, the log-likelihoods from the forward pass when run_backward is False, or None when the backward pass is run
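The return value depends on which arguments are set, and fw_output_fn takes precedence over run_backward. The following sketch mirrors only that documented control flow, using a hypothetical ToyPC class whose forward and backward methods merely record that they were called and produce fake log-likelihoods. query_sketch is not the library function: it omits **kwargs, whose handling is not described on this page, and the real TensorCircuit methods have different signatures.

```python
from typing import Callable, Optional


class ToyPC:
    """Hypothetical stand-in for a TensorCircuit that records which passes ran."""

    def __init__(self):
        self.calls = []

    def forward(self, inputs, input_layer_fn=None):
        self.calls.append("forward")
        # Fake "log-likelihoods", one per input item.
        return [float(-x) for x in inputs]

    def backward(self, inputs, input_layer_fn=None):
        self.calls.append("backward")


def query_sketch(pc, inputs, run_backward=True, fw_input_fn=None,
                 bk_input_fn=None, fw_output_fn: Optional[Callable] = None):
    """Mirror the documented return-value precedence of query() (sketch)."""
    lls = pc.forward(inputs, input_layer_fn=fw_input_fn)
    if fw_output_fn is not None:
        # fw_output_fn short-circuits: its result is returned, no backward pass.
        return fw_output_fn(pc)
    if not run_backward:
        # Forward-only query: return the log-likelihoods.
        return lls
    # Default: run the backward pass and return None.
    pc.backward(inputs, input_layer_fn=bk_input_fn)
    return None
```

With run_backward=False the sketch returns the forward result; with the defaults it runs both passes and returns None; and with fw_output_fn set it returns that function's result without ever calling backward, even if run_backward is True.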