Thursday, January 16, 2025

Posit AI Blog: torch 0.2.0

We are happy to announce that version 0.2.0 of torch
has just landed on CRAN.

This release includes many bug fixes and some nice new features
that we will present in this blog post. You can see the full changelog
in the NEWS.md file.

The features that we will discuss in detail are:

  • Multi-worker dataloaders
  • Initial support for JIT tracing
  • Print methods for nn_modules

Multi-worker dataloaders

dataloaders now respond to the num_workers argument and
will run the pre-processing in parallel workers.

For example, say we have the following dummy dataset that does
a long computation:

library(torch)
dat <- dataset(
  "mydataset",
  initialize = function(time, len = 10) {
    self$time <- time
    self$len <- len
  },
  .getitem = function(i) {
    Sys.sleep(self$time)
    torch_randn(1)
  },
  .length = function() {
    self$len
  }
)
ds <- dat(1)
system.time(ds[1])
   user  system elapsed 
  0.029   0.005   1.027 

We will now create two dataloaders, one that executes
sequentially and another executing in parallel.

seq_dl <- dataloader(ds, batch_size = 5)
par_dl <- dataloader(ds, batch_size = 5, num_workers = 2)

We can now compare the time it takes to process two batches sequentially to
the time it takes in parallel:

seq_it <- dataloader_make_iter(seq_dl)
par_it <- dataloader_make_iter(par_dl)

two_batches <- function(it) {
  dataloader_next(it)
  dataloader_next(it)
  "ok"
}

system.time(two_batches(seq_it))
system.time(two_batches(par_it))
   user  system elapsed 
  0.098   0.032  10.086 
   user  system elapsed 
  0.065   0.008   5.134 

Note that it is batches that are obtained in parallel, not individual observations. This way, we will be able to support
datasets with variable batch sizes in the future.
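
As a rough sanity check on the timings above, the expected durations can be estimated with simple arithmetic. This is only a sketch: the variable names are hypothetical, and the model assumes that each worker fetches one whole batch at a time and that worker startup and tensor transfer costs are negligible.

```r
# Idealized timing model (assumption: no startup or transfer overhead).
item_time  <- 1   # seconds per observation, as in ds <- dat(1)
batch_size <- 5
n_batches  <- 2
n_workers  <- 2

# Sequential: every observation of every batch is computed one after another.
seq_estimate <- item_time * batch_size * n_batches

# Parallel: whole batches are distributed across workers, so the number of
# "rounds" is the number of batches divided by the number of workers.
par_estimate <- item_time * batch_size * ceiling(n_batches / n_workers)

c(sequential = seq_estimate, parallel = par_estimate)
```

Under these assumptions the estimates are 10 and 5 seconds, which is close to the measured elapsed times of 10.086 and 5.134 seconds.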

Using multiple workers is not necessarily faster than serial execution because there is considerable overhead
when passing tensors from a worker to the main session, as
well as when initializing the workers.

This feature is enabled by the powerful callr package
and works on all operating systems supported by torch. callr lets
us create persistent R sessions, so we pay the overhead of transferring potentially large dataset
objects to the workers only once.

In the process of implementing this feature we have made
dataloaders behave like coro iterators.
This means that you can now use coro’s syntax
for looping through the dataloaders:

coro::loop(for(batch in par_dl) {
  print(batch$shape)
})
[1] 5 1
[1] 5 1
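
Because the loop body is evaluated in the calling environment, variables assigned inside coro::loop() remain available afterwards. The following is a minimal sketch using a hypothetical toy dataset built with tensor_dataset() (not the dataset from this post), where each batch arrives as a list whose first element holds the data:

```r
library(torch)

# Hypothetical toy data: 10 observations with 1 feature each.
x  <- torch_arange(1, 10)$unsqueeze(2)
ds <- tensor_dataset(x)
dl <- dataloader(ds, batch_size = 5)

n_batches <- 0
coro::loop(for (batch in dl) {
  # Count the batches and inspect the shape of each one.
  n_batches <- n_batches + 1
  print(batch[[1]]$shape)
})

n_batches
```

The same pattern is commonly used to accumulate losses or metrics over an epoch.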

This is the first torch release including the multi-worker
dataloaders feature, and you might run into edge cases when
using it. Do let us know if you find any problems.

Initial JIT support

Programs that make use of the torch package are inevitably
R programs, and thus they always need an R installation in order
to execute.

As of version 0.2.0, torch allows users to JIT trace
torch R functions into TorchScript. JIT (just-in-time) tracing will invoke
an R function with example inputs, record all operations that
occurred when the function was run, and return a script_function object
containing the TorchScript representation.

The nice thing about this is that TorchScript programs are easily
serializable, optimizable, and they can be loaded by another
program written in PyTorch or LibTorch without requiring any R
dependency.

Suppose you have the following R function that takes a tensor,
and does a matrix multiplication with a fixed weight matrix and
then adds a bias term:

w <- torch_randn(10, 1)
b <- torch_randn(1)
fn <- function(x) {
  a <- torch_mm(x, w)
  a + b
}

This function can be JIT-traced into TorchScript with jit_trace by passing the function and example inputs:

x <- torch_ones(2, 10)
tr_fn <- jit_trace(fn, x)
tr_fn(x)
torch_tensor
-0.6880
-0.6880
[ CPUFloatType{2,1} ]

Now all torch operations that happened when computing the result of
this function were traced and transformed into a graph, which you can inspect with tr_fn$graph:

graph(%0 : Float(2:10, 10:1, requires_grad=0, device=cpu)):
  %1 : Float(10:1, 1:1, requires_grad=0, device=cpu) = prim::Constant[value=-0.3532  0.6490 -0.9255  0.9452 -1.2844  0.3011  0.4590 -0.2026 -1.2983  1.5800 [ CPUFloatType{10,1} ]]()
  %2 : Float(2:1, 1:1, requires_grad=0, device=cpu) = aten::mm(%0, %1)
  %3 : Float(1:1, requires_grad=0, device=cpu) = prim::Constant[value={-0.558343}]()
  %4 : int = prim::Constant[value=1]()
  %5 : Float(2:1, 1:1, requires_grad=0, device=cpu) = aten::add(%2, %3, %4)
  return (%5)

The traced function can be serialized with jit_save:

jit_save(tr_fn, "linear.pt")

It can be reloaded in R with jit_load, and it can also be loaded in Python
with torch.jit.load.
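
On the R side, a save and reload round trip might look like the sketch below. It redefines the example function so that it is self-contained, writes to a temporary file rather than "linear.pt", and uses the hypothetical names path and loaded_fn. It assumes that the object returned by jit_load can be called like the original traced function.

```r
library(torch)

# Same kind of function as in the example above: matrix product plus bias.
w <- torch_randn(10, 1)
b <- torch_randn(1)
fn <- function(x) {
  torch_mm(x, w) + b
}

x <- torch_ones(2, 10)
tr_fn <- jit_trace(fn, x)

# Serialize to a temporary file and load it back.
path <- tempfile(fileext = ".pt")
jit_save(tr_fn, path)
loaded_fn <- jit_load(path)

# The reloaded function should reproduce the original R function's output.
same <- torch_allclose(fn(x), loaded_fn(x))
same
```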

Tracing will also allow you to take advantage of TorchScript to make your models
run faster!

Also note that tracing has some limitations, especially when your code has loops
or control flow statements that depend on tensor data. See ?jit_trace to
learn more.
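
To illustrate the control flow limitation, here is a small sketch with a hypothetical function whose branch depends on tensor data. It assumes that converting a tensor to an R value during tracing works without error. Since tracing only records the tensor operations that actually ran for the example input, the traced version keeps following the branch taken at trace time.

```r
library(torch)

# Hypothetical function: the branch depends on the data in x.
branchy <- function(x) {
  if (as.numeric(x$sum()) > 0) {
    x * 2
  } else {
    x * -1
  }
}

pos <- torch_ones(3)
neg <- -torch_ones(3)

# Tracing with a positive input records only the `x * 2` branch.
traced <- jit_trace(branchy, pos)

branchy(neg)  # plain R function: takes the `else` branch
traced(neg)   # traced function: still multiplies by 2
```

If a model needs data-dependent branching, it is safer to keep that logic outside of the traced function.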

New print method for nn_modules

In this release we have also improved the nn_module printing methods in order
to make it easier to understand what’s inside.

For example, if you create an instance of an nn_linear module, such as nn_linear(10, 1), and print it, you will
see:

An `nn_module` containing 11 parameters.

── Parameters ──────────────────────────────────────────────────────────────────
● weight: Float [1:1, 1:10]
● bias: Float [1:1]

You immediately see the total number of parameters in the module as well as
their names and shapes.

This also works for custom modules (possibly including sub-modules). For example:

my_module <- nn_module(
  initialize = function() {
    self$linear <- nn_linear(10, 1)
    self$param <- nn_parameter(torch_randn(5,1))
    self$buff <- nn_buffer(torch_randn(5))
  }
)
my_module()
An `nn_module` containing 16 parameters.

── Modules ─────────────────────────────────────────────────────────────────────
● linear: <nn_linear> #11 parameters

── Parameters ──────────────────────────────────────────────────────────────────
● param: Float [1:5, 1:1]

── Buffers ─────────────────────────────────────────────────────────────────────
● buff: Float [1:5]

We hope this makes it easier to understand nn_module objects.
We have also improved autocomplete support for nn_modules and we will now
show all sub-modules, parameters and buffers while you type.
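
The parameter total reported by the print method can be cross-checked by hand. The sketch below rebuilds the custom module from above and sums the number of elements over the module's parameters field; the names m and sizes are hypothetical. Buffers are not parameters, so buff does not contribute to the count.

```r
library(torch)

my_module <- nn_module(
  initialize = function() {
    self$linear <- nn_linear(10, 1)                 # 10 weights + 1 bias
    self$param <- nn_parameter(torch_randn(5, 1))   # 5 elements
    self$buff <- nn_buffer(torch_randn(5))          # buffer, not counted
  }
)

m <- my_module()

# Number of elements in each parameter, including those of sub-modules.
sizes <- sapply(m$parameters, function(p) p$numel())
sizes

# Total parameter count: 10 + 1 + 5.
sum(sizes)
```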

torchaudio

torchaudio is an extension for torch developed by Athos Damiani (@athospd), providing audio loading, transformations, common architectures for signal processing, pre-trained weights and access to commonly used datasets. It is an almost literal translation of PyTorch’s Torchaudio library to R.

torchaudio is not yet on CRAN, but you can already try the development version
available here.

You can also visit the pkgdown website for examples and reference documentation.

Other features and bug fixes

Thanks to community contributions, we have found and fixed many bugs in torch.
We have also added a number of new features.

You can see the full list of changes in the NEWS.md file.

Thanks very much for reading this blog post, and feel free to reach out on GitHub for help or discussions!

The photo used in this post preview is by Oleg Illarionov on Unsplash