pyjuice.transformations.prune_by_score
- pyjuice.transformations.prune_by_score(root_nodes: CircuitNodes, key: str = '_scores', scores: Dict[CircuitNodes, ndarray | Tensor] | None = None, keep_frac: float | None = None, score_threshold: float | None = None, block_reduction: str = 'sum')
Prune sum-edge connections from a PC based on per-edge scores, returning a new, sparser PC. Edges are kept either by retaining the top keep_frac fraction of the highest-scoring edges, or by keeping all edges whose score is at least score_threshold (exactly one of the two must be given). Parameter-tied nodes and nodes without scores are left untouched.
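The two selection rules, plus the block reduction applied when blocks have size > 1, can be illustrated with plain NumPy. The following is a minimal sketch under stated assumptions, not pyjuice's implementation. The helpers `reduce_block_scores` and `select_kept_edges` are hypothetical. The sketch assumes that consecutive groups of `block_size` edges form one block, that `keep_frac * num_edges` is rounded to the nearest integer, and that ties are broken in favor of earlier edges. It also operates on a single flat score array, whereas the docstring does not say whether `keep_frac` is applied per node or across the whole PC.

```python
import numpy as np


def reduce_block_scores(edge_scores: np.ndarray, block_size: int,
                        block_reduction: str = "sum") -> np.ndarray:
    """Hypothetical helper: collapse per-edge scores into one score per block.

    Assumes consecutive runs of `block_size` edges belong to the same block.
    """
    edge_scores = np.asarray(edge_scores, dtype=float)
    if edge_scores.size % block_size != 0:
        raise ValueError("number of edges must be a multiple of block_size")
    blocks = edge_scores.reshape(-1, block_size)
    if block_reduction == "sum":
        return blocks.sum(axis=1)
    if block_reduction == "max":  # assumed extra option, for illustration only
        return blocks.max(axis=1)
    raise ValueError(f"unsupported block_reduction: {block_reduction!r}")


def select_kept_edges(scores, keep_frac=None, score_threshold=None) -> np.ndarray:
    """Hypothetical helper: return a boolean mask of the edges to keep."""
    # Mirror the documented contract: exactly one criterion must be given.
    if (keep_frac is None) == (score_threshold is None):
        raise ValueError("exactly one of keep_frac and score_threshold must be given")
    scores = np.asarray(scores, dtype=float)

    if score_threshold is not None:
        # Keep every edge whose score is at least the threshold.
        return scores >= score_threshold

    if not 0.0 <= keep_frac <= 1.0:
        raise ValueError("keep_frac must lie in [0, 1]")
    # Assumption: the number of kept edges is rounded to the nearest integer.
    num_keep = int(round(keep_frac * scores.size))
    mask = np.zeros(scores.shape, dtype=bool)
    if num_keep > 0:
        # Sort by descending score; a stable sort keeps earlier edges on ties.
        order = np.argsort(-scores, kind="stable")
        mask[order[:num_keep]] = True
    return mask


if __name__ == "__main__":
    per_edge = np.array([0.1, 0.4, 0.2, 0.1, 0.9, 0.05])
    per_block = reduce_block_scores(per_edge, block_size=2)  # ~[0.5, 0.3, 0.95]
    print(select_kept_edges(per_block, keep_frac=0.67))       # [ True False  True]
    print(select_kept_edges(per_block, score_threshold=0.4))  # [ True False  True]
```

Here `keep_frac=0.67` keeps round(0.67 × 3) = 2 of the three blocks, namely the two highest-scoring ones. `score_threshold=0.4` keeps every block scoring at least 0.4, which happens to select the same two blocks.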
- Parameters:
root_nodes (CircuitNodes) – the root of the PC to prune
key (str) – the attribute name under which per-node edge scores are stored on each node (used when scores is not given)
scores (Optional[Dict[CircuitNodes, Union[ndarray, Tensor]]]) – an explicit mapping from nodes to their edge scores (NumPy arrays or tensors); overrides key when provided
keep_frac (Optional[float]) – the fraction of highest-scoring edges to keep; mutually exclusive with score_threshold
score_threshold (Optional[float]) – the minimum score for an edge to be kept; mutually exclusive with keep_frac
block_reduction (str) – how per-edge scores are reduced over a node block when blocks have size > 1 (e.g., “sum”)
- Returns:
a new PC with the low-scoring edges removed
- Return type: