.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/01_train_pc.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_getting-started_tutorials_01_train_pc.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_01_train_pc.py:


Train a PC
==========

This tutorial demonstrates how to create a Hidden Chow-Liu Tree (https://arxiv.org/pdf/2106.02264.pdf) using :code:`pyjuice.structures` and how to train the model with mini-batch EM and full-batch EM. For simplicity, we use the MNIST dataset as an example. Note that the goal of this tutorial is only to demonstrate the basic training pipeline of PyJuice; additional details, such as the different ways to construct a PC, are covered in the following tutorials.

.. GENERATED FROM PYTHON SOURCE LINES 10-13

.. code-block:: Python


    # sphinx_gallery_thumbnail_path = 'imgs/juice.png'

.. GENERATED FROM PYTHON SOURCE LINES 14-16

Load the MNIST Dataset
----------------------

.. GENERATED FROM PYTHON SOURCE LINES 16-43

.. code-block:: Python


    import pyjuice as juice
    import torch
    import torchvision
    import time
    from torch.utils.data import TensorDataset, DataLoader
    import pyjuice.nodes.distributions as dists

    # Download MNIST and flatten every 28x28 image into a vector of 784 pixel values
    train_dataset = torchvision.datasets.MNIST(root = "../data", train = True, download = True)
    valid_dataset = torchvision.datasets.MNIST(root = "../data", train = False, download = True)

    train_data = train_dataset.data.reshape(60000, 28*28)
    valid_data = valid_dataset.data.reshape(10000, 28*28)

    train_loader = DataLoader(
        dataset = TensorDataset(train_data),
        batch_size = 512,
        shuffle = True,
        drop_last = True
    )
    valid_loader = DataLoader(
        dataset = TensorDataset(valid_data),
        batch_size = 512,
        shuffle = False,
        drop_last = True
    )

.. GENERATED FROM PYTHON SOURCE LINES 44-46

Create the PC
-------------

.. GENERATED FROM PYTHON SOURCE LINES 48-49

Let's create an HCLT PC with latent size 128.

.. GENERATED FROM PYTHON SOURCE LINES 49-58

.. code-block:: Python


    device = torch.device("cuda:0")

    # The data is required to construct the backbone Chow-Liu Tree structure for the HCLT
    ns = juice.structures.HCLT(
        train_data.float().to(device),
        num_latents = 128
    )

.. GENERATED FROM PYTHON SOURCE LINES 59-65

:code:`ns` is a Directed Acyclic Graph (DAG) representation of the PC. Specifically, we use :code:`pyjuice.nodes.InputNodes`, :code:`pyjuice.nodes.ProdNodes`, and :code:`pyjuice.nodes.SumNodes` to define vectors of input nodes, product nodes, and sum nodes, respectively. By also storing the topological structure of the node vectors (i.e., every node vector keeps pointers to its child node vectors), we represent the PC as a DAG; :code:`ns` itself is simply the node vector containing the root node of the PC.

While user-friendly, the DAG-based representation is not amenable to efficient computation. Therefore, before doing any computation, we need to compile the PC with :code:`pyjuice.compile`, which creates a compact and equivalent representation of the PC.

.. GENERATED FROM PYTHON SOURCE LINES 65-68

.. code-block:: Python


    pc = juice.compile(ns)

.. GENERATED FROM PYTHON SOURCE LINES 69-71

:code:`pc` is an instance of :code:`torch.nn.Module`, so we can treat it just like a neural network that takes variable assignments :math:`\mathbf{x}` as input and returns their log-likelihoods :math:`\log p(\mathbf{x})` as output. We proceed to move it to the GPU specified by :code:`device`.
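Before doing so, we can optionally verify that :code:`pc` indeed behaves like a standard PyTorch module. The snippet below is only an illustrative sketch: exactly which tensors PyJuice registers as parameters versus buffers is an implementation detail of the library.

.. code-block:: Python


    # `pc` is a regular torch.nn.Module, so the usual module utilities apply.
    assert isinstance(pc, torch.nn.Module)

    # Count the tensors registered on the compiled circuit (what exactly is
    # registered as a parameter vs. a buffer is up to PyJuice internals).
    num_params = sum(p.numel() for p in pc.parameters())
    num_buffers = sum(b.numel() for b in pc.buffers())
    print(f"registered parameter elements: {num_params}; buffer elements: {num_buffers}")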
.. GENERATED FROM PYTHON SOURCE LINES 71-74

.. code-block:: Python


    pc.to(device)

.. GENERATED FROM PYTHON SOURCE LINES 75-77

Train the PC
------------

.. GENERATED FROM PYTHON SOURCE LINES 79-80

We start by defining the optimizer and scheduler.

.. GENERATED FROM PYTHON SOURCE LINES 80-89

.. code-block:: Python


    optimizer = juice.optim.CircuitOptimizer(pc, lr = 0.1, pseudocount = 0.1, method = "EM")
    scheduler = juice.optim.CircuitScheduler(
        optimizer,
        method = "multi_linear",
        lrs = [0.9, 0.1, 0.05],
        # The step-size schedule is specified in units of mini-batch steps
        # (len(train_loader) steps per epoch)
        milestone_steps = [0, len(train_loader) * 100, len(train_loader) * 350]
    )

.. GENERATED FROM PYTHON SOURCE LINES 90-91

Optionally, we can leverage CUDA Graphs to hide the kernel launch overhead by doing a dry run.

.. GENERATED FROM PYTHON SOURCE LINES 91-99

.. code-block:: Python


    # A single dry-run forward/backward pass records the CUDA Graphs reused in later iterations
    for batch in train_loader:
        x = batch[0].to(device)

        lls = pc(x, record_cudagraph = True)
        lls.mean().backward()
        break

.. GENERATED FROM PYTHON SOURCE LINES 100-101

We are now ready to train the PC. Below is an example training loop for mini-batch EM. (A short sketch at the end of this page shows how to convert the reported average log-likelihoods into bits-per-dimension.)

.. GENERATED FROM PYTHON SOURCE LINES 101-137

.. code-block:: Python


    for epoch in range(1, 350+1):
        t0 = time.time()
        train_ll = 0.0
        for batch in train_loader:
            x = batch[0].to(device)

            # Similar to PyTorch optimizers zeroing out the gradients, we zero out the parameter flows
            optimizer.zero_grad()

            # Forward pass
            lls = pc(x)

            # Backward pass
            lls.mean().backward()

            train_ll += lls.mean().detach().cpu().numpy().item()

            # Perform a mini-batch EM step
            optimizer.step()
            scheduler.step()

        train_ll /= len(train_loader)

        t1 = time.time()
        test_ll = 0.0
        for batch in valid_loader:
            x = batch[0].to(pc.device)
            lls = pc(x)
            test_ll += lls.mean().detach().cpu().numpy().item()

        test_ll /= len(valid_loader)
        t2 = time.time()

        print(f"[Epoch {epoch}/{350}][train LL: {train_ll:.2f}; val LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; val forward {t2-t1:.2f}] ")

.. GENERATED FROM PYTHON SOURCE LINES 138-139

Similarly, an example training loop for full-batch EM is given below. Instead of taking an EM step after every mini-batch, we accumulate the parameter flows over the whole epoch and perform a single update with step size 1.0.

.. GENERATED FROM PYTHON SOURCE LINES 139-171

.. code-block:: Python


    for epoch in range(1, 1+1):
        t0 = time.time()

        # Manually zero out the parameter flows
        pc.init_param_flows(flows_memory = 0.0)

        train_ll = 0.0
        for batch in train_loader:
            x = batch[0].to(device)

            # We only run the forward and the backward pass, and accumulate the flows throughout the epoch
            lls = pc(x)
            lls.mean().backward()

            train_ll += lls.mean().detach().cpu().numpy().item()

        # Set step size to 1.0 for full-batch EM
        pc.mini_batch_em(step_size = 1.0, pseudocount = 0.01)

        train_ll /= len(train_loader)

        t1 = time.time()
        test_ll = 0.0
        for batch in valid_loader:
            x = batch[0].to(pc.device)
            lls = pc(x)
            test_ll += lls.mean().detach().cpu().numpy().item()

        test_ll /= len(valid_loader)
        t2 = time.time()

        print(f"[Epoch {epoch}/{1}][train LL: {train_ll:.2f}; val LL: {test_ll:.2f}].....[train forward+backward+step {t1-t0:.2f}; val forward {t2-t1:.2f}] ")


.. _sphx_glr_download_getting-started_tutorials_01_train_pc.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 01_train_pc.ipynb <01_train_pc.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 01_train_pc.py <01_train_pc.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 01_train_pc.zip <01_train_pc.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
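As a supplementary note (not part of the generated example above), the average log-likelihoods printed by the training loops can be converted into bits-per-dimension, a common metric for density estimation on image data. The sketch below assumes the log-likelihoods are natural logarithms and uses 28 * 28 = 784 dimensions for MNIST; the value -650 is only a hypothetical input for illustration.

.. code-block:: Python


    import math

    def ll_to_bpd(avg_ll, num_dims = 28 * 28):
        # Convert an average log-likelihood (in nats) to bits-per-dimension
        return -avg_ll / (num_dims * math.log(2))

    # For instance, a (hypothetical) average log-likelihood of -650 nats
    # corresponds to roughly 1.20 bits-per-dimension
    print(f"{ll_to_bpd(-650.0):.2f} bpd")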