pyjuice.structures.PD
- pyjuice.structures.PD(data_shape: Tuple, num_latents: int, split_intervals: int | Tuple[int] | None = None, split_points: Sequence[Sequence[int]] | None = None, max_split_depth: int | None = None, max_prod_block_conns: int = 4, structure_type: str = 'sum_dominated', input_ns_fn: Callable | None = None, input_dist: Distribution | None = None, input_node_type: Type[Distribution] = Categorical, input_node_params: Dict = {'num_cats': 256}, tie_homogeneous_params: bool = False, block_size: int | None = None)
Generate probabilistic circuits (PCs) with the Poon-Domingos (PD) structure (https://arxiv.org/pdf/1202.3732.pdf).
- Parameters:
data_shape (Tuple) – shape of the data (e.g., [H, W, 3] for images and [S] for sequences)
num_latents (int) – size of the latent space
split_intervals (Optional[Union[int, Tuple[int]]]) – an integer or a tuple of integers specifying the interval between split points in each dimension; either this or split_points must be specified
split_points (Optional[Sequence[Sequence[int]]]) – a sequence of split points in each dimension; either this or split_intervals needs to be specified
max_split_depth (Optional[int]) – maximum depth of the splits
max_prod_block_conns (int) – the maximum number of product nodes connected to every sum node
structure_type (str) – whether to reuse sum nodes: "sum_dominated" (no reuse) or "prod_dominated" (reuse)
input_dist (Optional[Distribution]) – input distribution
tie_homogeneous_params (bool) – whether to tie parameters of sum/input nodes with compatible structures
block_size (Optional[int]) – block size
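A minimal usage sketch, assuming a 1-D sequence of length 64 and illustrative argument values. The keyword names are taken from the signature above; `check_split_args` and `build_pd` are hypothetical helpers introduced here and are not part of pyjuice. The pre-check reads "either this or split_points needs to be specified" as "at least one of the two must be given", which is an interpretation of the documented constraint rather than confirmed library behavior.

```python
from typing import Any, Dict

# Illustrative values only; the keyword names come from the PD signature.
pd_kwargs: Dict[str, Any] = {
    "data_shape": (64,),               # a sequence of length S = 64
    "num_latents": 32,                 # size of the latent space
    "split_intervals": (8,),           # one interval per data dimension
    "structure_type": "sum_dominated",
    "input_node_params": {"num_cats": 256},
}


def check_split_args(kwargs: Dict[str, Any]) -> None:
    """Hypothetical pre-check: require split_intervals or split_points."""
    if kwargs.get("split_intervals") is None and kwargs.get("split_points") is None:
        raise ValueError("specify either split_intervals or split_points")


def build_pd(kwargs: Dict[str, Any]):
    """Validate the arguments, then call pyjuice.structures.PD.

    pyjuice is imported lazily, so this module loads even when the
    library is not installed.
    """
    check_split_args(kwargs)
    import pyjuice as juice
    return juice.structures.PD(**kwargs)


check_split_args(pd_kwargs)  # passes: split_intervals is set
```

With pyjuice installed, `ns = build_pd(pd_kwargs)` would forward these arguments to `pyjuice.structures.PD`. Passing `split_points` instead of `split_intervals` satisfies the same pre-check.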