pyjuice.optim
A PyTorch-style optimizer and learning-rate scheduler for training probabilistic circuits (PCs). Both mirror the
torch.optim API, so a PC training loop reads like a standard PyTorch loop
(zero_grad / step) and can optionally update attached neural-network parameters in the same step.
- class pyjuice.optim.CircuitOptimizer(pc: TensorCircuit, base_optimizer: Optimizer | None = None, method: str = 'EM', lr: float = 0.1, pseudocount: float = 0.1, **kwargs)
A PyTorch-style optimizer for PCs that wraps PyJuice’s parameter-update routines (e.g., EM).
It mirrors the torch.optim.Optimizer API (zero_grad(), step(), state_dict(), load_state_dict()), so a training loop looks the same as a standard PyTorch loop. An optional base_optimizer can be supplied to additionally update any non-PC (e.g., neural-network) parameters in the same step.
- Parameters:
pc (TensorCircuit) – the PC to optimize
base_optimizer (Optional[torch.optim.Optimizer]) – an optional PyTorch optimizer for non-PC parameters, stepped alongside the PC update
method (str) – the parameter-update method; one of “EM”, “Viterbi”, or “GeneralEM”
lr (float) – the step size (learning rate) of the PC parameter update
pseudocount (float) – the Laplace smoothing pseudocount added during the update
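To show how lr and pseudocount interact, here is a minimal pure-Python sketch of a mini-batch EM step for the child weights of a single sum node. It assumes that the update moves the old weights a fraction lr of the way toward the normalized, pseudocount-smoothed flows, and that the pseudocount is spread evenly over the children. The function name em_update and this exact formula are illustrative assumptions; this is not PyJuice's internal kernel.

```python
from typing import List


def em_update(weights: List[float], flows: List[float],
              lr: float = 0.1, pseudocount: float = 0.1) -> List[float]:
    """Hypothetical mini-batch EM step for one sum node's child weights.

    weights: current normalized weights of the sum node's children
    flows:   expected counts (parameter flows) accumulated over a mini-batch
    """
    n = len(weights)
    # Laplace smoothing (assumed): spread the pseudocount evenly over children.
    smoothed = [f + pseudocount / n for f in flows]
    total = sum(smoothed)
    target = [s / total for s in smoothed]  # full M-step estimate for this batch
    # Move only a fraction `lr` of the way toward the M-step estimate.
    return [(1.0 - lr) * w + lr * t for w, t in zip(weights, target)]


new = em_update([0.5, 0.5], [3.0, 1.0], lr=0.5, pseudocount=0.0)
print(new)  # [0.625, 0.375]
```

Under this formulation, lr=1.0 reduces to a plain EM M-step on the current batch, and the result stays normalized for any lr in [0, 1] because it is a convex combination of two distributions.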
- zero_grad()
- step(closure=None)
- state_dict()
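The zero_grad() / step() contract can be illustrated with a toy, dependency-free mimic. It assumes that zero_grad() clears the accumulated flows, that step() applies an EM-style update from them, and that both calls are also forwarded to base_optimizer when one is given. ToyCircuitOptimizer, RecordingOptimizer, and the accumulate() method (standing in for the forward and backward passes) are hypothetical names used only for this sketch and do not mirror PyJuice's source.

```python
from typing import List, Optional


class RecordingOptimizer:
    """Hypothetical stand-in for a torch.optim optimizer over non-PC parameters."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def zero_grad(self) -> None:
        self.calls.append("zero_grad")

    def step(self) -> None:
        self.calls.append("step")


class ToyCircuitOptimizer:
    """Toy mimic of the zero_grad()/step() contract; NOT PyJuice's implementation."""

    def __init__(self, weights: List[float],
                 base_optimizer: Optional[RecordingOptimizer] = None,
                 lr: float = 0.1, pseudocount: float = 0.1) -> None:
        self.weights = list(weights)
        self.base_optimizer = base_optimizer
        self.lr = lr
        self.pseudocount = pseudocount
        self.flows = [0.0] * len(self.weights)

    def zero_grad(self) -> None:
        if self.base_optimizer is not None:
            self.base_optimizer.zero_grad()
        self.flows = [0.0] * len(self.weights)  # reset expected counts

    def accumulate(self, batch_counts: List[float]) -> None:
        # Hypothetical stand-in for the forward/backward pass that fills flows.
        self.flows = [f + c for f, c in zip(self.flows, batch_counts)]

    def step(self) -> None:
        if self.base_optimizer is not None:
            self.base_optimizer.step()
        n = len(self.weights)
        smoothed = [f + self.pseudocount / n for f in self.flows]
        total = sum(smoothed)
        # Interpolate toward the normalized flows with step size lr.
        self.weights = [(1.0 - self.lr) * w + self.lr * s / total
                        for w, s in zip(self.weights, smoothed)]


base = RecordingOptimizer()
opt = ToyCircuitOptimizer([0.5, 0.5], base_optimizer=base, lr=0.5, pseudocount=0.0)
for batch_counts in ([3.0, 1.0], [3.0, 1.0]):
    opt.zero_grad()               # clear PC flows (and base-optimizer grads)
    opt.accumulate(batch_counts)  # stand-in for forward + backward
    opt.step()                    # PC update (and base-optimizer step)
print(opt.weights)  # [0.6875, 0.3125]
print(base.calls)   # ['zero_grad', 'step', 'zero_grad', 'step']
```

The loop has the same three-beat shape as a torch.optim loop. Flows play the role that gradients play there, so they are cleared at the start of every batch.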
- class pyjuice.optim.CircuitScheduler(optimizer: CircuitOptimizer, base_scheduler: LRScheduler | None = None, method: str = 'constant', **kwargs)
A learning-rate scheduler for a CircuitOptimizer, analogous to the schedulers in torch.optim.lr_scheduler. Calling step() updates the optimizer’s step size according to the chosen schedule.
- Parameters:
optimizer (CircuitOptimizer) – the circuit optimizer whose step size is being scheduled
base_scheduler (Optional[torch.optim.lr_scheduler.LRScheduler]) – an optional PyTorch scheduler for the optimizer’s attached base_optimizer, stepped alongside the PC schedule
method (str) – the schedule type; one of “constant” or “multi_linear” (piecewise-linear between milestones)
For the “multi_linear” schedule, pass lrs (the learning rates at each milestone) and milestone_steps (the step indices of the milestones) as keyword arguments.
- step()
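The “multi_linear” schedule can be sketched as linear interpolation between consecutive (milestone_steps[i], lrs[i]) points. This sketch assumes strictly increasing milestones and clamps to the first or last rate outside the milestone range; that clamping behavior and the helper name multi_linear_lr are assumptions, not documented PyJuice behavior.

```python
from bisect import bisect_right
from typing import Sequence


def multi_linear_lr(step: int, lrs: Sequence[float],
                    milestone_steps: Sequence[int]) -> float:
    """Hypothetical piecewise-linear learning rate at a given step."""
    if not lrs or len(lrs) != len(milestone_steps):
        raise ValueError("lrs and milestone_steps must be non-empty and equal length")
    if step <= milestone_steps[0]:
        return lrs[0]   # before the first milestone: hold the first rate
    if step >= milestone_steps[-1]:
        return lrs[-1]  # after the last milestone: hold the last rate
    # Find the segment with milestone_steps[i] <= step < milestone_steps[i + 1].
    i = bisect_right(milestone_steps, step) - 1
    m0, m1 = milestone_steps[i], milestone_steps[i + 1]
    frac = (step - m0) / (m1 - m0)
    return lrs[i] + frac * (lrs[i + 1] - lrs[i])


lrs = [0.9, 0.1, 0.05]   # illustrative values only
milestones = [0, 100, 400]
for t in (0, 50, 100, 250, 1000):
    print(t, round(multi_linear_lr(t, lrs, milestones), 6))
# prints 0.9, 0.5, 0.1, 0.075, 0.05 for the five steps
```

A scheduler built on this idea would keep a step counter, increment it on each step() call, and write the returned value into the optimizer’s step size.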