pyjuice.queries.query
- pyjuice.queries.query(pc: TensorCircuit, inputs: Tensor, run_backward: bool = True, fw_input_fn: str | Callable | None = None, bk_input_fn: str | Callable | None = None, fw_output_fn: Callable | None = None, **kwargs)
A general-purpose entry point for running queries on a PC. It runs a forward pass of the PC and, optionally, a backward pass, using custom input-layer functions if they are given. It is the common backend of marginal(), conditional(), and sample().
- Parameters:
pc (TensorCircuit) – the input PC
inputs (torch.Tensor) – input tensor of size [B, num_vars], or of a custom shape when a matching fw_input_fn is supplied
run_backward (bool) – whether to run the backward pass after the forward pass
fw_input_fn (Optional[Union[str,Callable]]) – an optional custom function (or the name of an input-layer method) for the forward pass of input layers
bk_input_fn (Optional[Union[str,Callable]]) – an optional custom function (or the name of an input-layer method) for the backward pass of input layers
fw_output_fn (Optional[Callable]) – an optional function applied to pc right after the forward pass; if provided, its return value is returned immediately and no backward pass is run
- Returns:
the output of fw_output_fn if it is provided; otherwise the log-likelihoods from the forward pass when run_backward is False; otherwise None (after the backward pass)
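The return-value rules above can be hard to read off a parameter list, so the following is a minimal, self-contained sketch of the documented control flow. It is not pyjuice code. ToyCircuit, ToyInputLayer, forward_squared and query_sketch are hypothetical stand-ins, the per-sample numbers are placeholder scores rather than real log-likelihoods, and the (layer, x) calling convention for a callable input function is an assumption made only for illustration. What it mirrors from the description: fw_output_fn is applied right after the forward pass and short-circuits the backward pass, run_backward=False returns the forward results, and a full run returns None. It also shows one plausible way a "str or Callable" input function could be resolved, namely by looking up a method by name on the input layer.

```python
from typing import Callable, List, Optional, Union

InputFn = Optional[Union[str, Callable]]


class ToyInputLayer:
    """Hypothetical input layer (NOT a pyjuice class)."""

    def forward(self, x: List[float]) -> List[float]:
        # Default per-variable placeholder score: the value itself.
        return list(x)

    def forward_squared(self, x: List[float]) -> List[float]:
        # Alternative method that can be selected by passing its name as a string.
        return [v * v for v in x]


class ToyCircuit:
    """Hypothetical stand-in for a TensorCircuit (NOT the pyjuice API)."""

    def __init__(self) -> None:
        self.input_layer = ToyInputLayer()
        self.last_lls: Optional[List[float]] = None
        self.backward_called = False
        self.backward_fn: InputFn = None

    def _run_input_layer(self, x: List[float], fn: InputFn) -> List[float]:
        if fn is None:
            return self.input_layer.forward(x)           # default method
        if isinstance(fn, str):
            return getattr(self.input_layer, fn)(x)      # method looked up by name
        return fn(self.input_layer, x)                   # assumed (layer, x) signature

    def forward(self, inputs: List[List[float]], input_fn: InputFn = None) -> List[float]:
        # One placeholder score per sample: sum of the input-layer outputs.
        self.last_lls = [sum(self._run_input_layer(row, input_fn)) for row in inputs]
        return self.last_lls

    def backward(self, inputs: List[List[float]], input_fn: InputFn = None) -> None:
        # A real backward pass would propagate flows; here we only record the call.
        self.backward_called = True
        self.backward_fn = input_fn


def query_sketch(pc: ToyCircuit, inputs, run_backward: bool = True,
                 fw_input_fn: InputFn = None, bk_input_fn: InputFn = None,
                 fw_output_fn: Optional[Callable] = None):
    lls = pc.forward(inputs, input_fn=fw_input_fn)
    if fw_output_fn is not None:
        # Applied right after the forward pass; no backward pass is run.
        return fw_output_fn(pc)
    if not run_backward:
        return lls
    pc.backward(inputs, input_fn=bk_input_fn)
    return None


if __name__ == "__main__":
    pc = ToyCircuit()
    x = [[1, 2], [3, 4]]
    print(query_sketch(pc, x, run_backward=False))                                   # [3, 7]
    print(query_sketch(pc, x, run_backward=False, fw_input_fn="forward_squared"))    # [5, 25]
    print(query_sketch(pc, x, fw_output_fn=lambda c: max(c.last_lls)))               # 7
    print(query_sketch(pc, x), pc.backward_called)                                   # None True
```

Checking fw_output_fn before run_backward reflects the description that its return value is returned immediately after the forward pass, so in this sketch it takes precedence regardless of run_backward.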