We are happy to announce that version 0.2.0 of torch has just landed on CRAN.
This release includes many bug fixes and some nice new features
that we will present in this blog post. You can see the full changelog
in the NEWS.md file.
The features that we will discuss in detail are:
- Initial support for JIT tracing
- Multi-worker dataloaders
- Print methods for nn_modules
Multi-worker dataloaders
Dataloaders now respond to the num_workers argument and
will run the pre-processing in parallel workers.
For example, say we have the following dummy dataset that does
a long computation:
library(torch)
dat <- dataset(
  "mydataset",
  initialize = function(time, len = 10) {
    self$time <- time
    self$len <- len
  },
  .getitem = function(i) {
    Sys.sleep(self$time)
    torch_randn(1)
  },
  .length = function() {
    self$len
  }
)
ds <- dat(1)
system.time(ds[1])
user system elapsed
0.029 0.005 1.027
We will now create two dataloaders, one that executes
sequentially and another executing in parallel.
seq_dl <- dataloader(ds, batch_size = 5)
par_dl <- dataloader(ds, batch_size = 5, num_workers = 2)
We can now compare the time it takes to process two batches sequentially to
the time it takes in parallel:
seq_it <- dataloader_make_iter(seq_dl)
par_it <- dataloader_make_iter(par_dl)
two_batches <- function(it) {
  dataloader_next(it)
  dataloader_next(it)
  "ok"
}
system.time(two_batches(seq_it))
system.time(two_batches(par_it))
user system elapsed
0.098 0.032 10.086
user system elapsed
0.065 0.008 5.134
Note that it is batches that are obtained in parallel, not individual observations. This design will let us support
datasets with variable batch sizes in the future.
Using multiple workers is not necessarily faster than serial execution because there’s a considerable overhead
when passing tensors from a worker to the main session as
well as when initializing the workers.
This feature is enabled by the powerful callr package
and works on all operating systems supported by torch. callr lets
us create persistent R sessions, so we pay the overhead of transferring potentially large dataset
objects to workers only once.
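To make the "pay the transfer cost only once" idea concrete, here is a minimal sketch using callr's persistent `r_session` directly. It is not how torch implements its workers internally; the names `worker`, `big` and `cached` are purely illustrative.

```r
library(callr)

# Start one persistent background R session (our stand-in for a worker).
worker <- r_session$new()

# Ship a large object to the worker once and keep it in its global environment.
big <- runif(1e6)
worker$run(
  function(obj) {
    assign("cached", obj, envir = .GlobalEnv)
    invisible(NULL)
  },
  args = list(obj = big)
)

# Later calls reuse the cached object; only the small results travel back.
cached_len <- worker$run(function() length(get("cached", envir = .GlobalEnv)))
cached_sum <- worker$run(function() sum(get("cached", envir = .GlobalEnv)))

worker$close()
```

Because the session stays alive between calls, the second and third `run()` calls do not need to send `big` again.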
In the process of implementing this feature we have made
dataloaders behave like coro iterators.
This means that you can now use coro’s syntax
for looping through the dataloaders:
coro::loop(for (batch in par_dl) {
  print(batch$shape)
})
[1] 5 1
[1] 5 1
This is the first torch release including the multi-worker
dataloaders feature, and you might run into edge cases when
using it. Do let us know if you find any problems.
Initial JIT support
Programs that make use of the torch package are inevitably
R programs and thus, they always need an R installation in order
to execute.
As of version 0.2.0, torch allows users to JIT trace
torch R functions into TorchScript. JIT (just-in-time) tracing invokes
an R function with example inputs, records all operations that
occurred when the function was run, and returns a script_function object
containing the TorchScript representation.
The nice thing about this is that TorchScript programs are easily
serializable, optimizable, and they can be loaded by another
program written in PyTorch or LibTorch without requiring any R
dependency.
Suppose you have the following R function that takes a tensor,
and does a matrix multiplication with a fixed weight matrix and
then adds a bias term:
w <- torch_randn(10, 1)
b <- torch_randn(1)
fn <- function(x) {
  a <- torch_mm(x, w)
  a + b
}
This function can be JIT-traced into TorchScript with jit_trace
by passing the function and example inputs:
x <- torch_ones(2, 10)
tr_fn <- jit_trace(fn, x)
tr_fn(x)
torch_tensor
-0.6880
-0.6880
[ CPUFloatType{2,1} ]
Now all torch operations that happened when computing the result of
this function were traced and transformed into a graph:
graph(%0 : Float(2:10, 10:1, requires_grad=0, device=cpu)):
%1 : Float(10:1, 1:1, requires_grad=0, device=cpu) = prim::Constant[value=-0.3532 0.6490 -0.9255 0.9452 -1.2844 0.3011 0.4590 -0.2026 -1.2983 1.5800 [ CPUFloatType{10,1} ]]()
%2 : Float(2:1, 1:1, requires_grad=0, device=cpu) = aten::mm(%0, %1)
%3 : Float(1:1, requires_grad=0, device=cpu) = prim::Constant[value={-0.558343}]()
%4 : int = prim::Constant[value=1]()
%5 : Float(2:1, 1:1, requires_grad=0, device=cpu) = aten::add(%2, %3, %4)
return (%5)
The traced function can be serialized with jit_save:
jit_save(tr_fn, "linear.pt")
It can be reloaded in R with jit_load, but it can also be reloaded in Python
with torch.jit.load:
import torch
fn = torch.jit.load("linear.pt")
fn(torch.ones(2, 10))
tensor([[-0.6880],
[-0.6880]])
How cool is that?!
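The R side of the round trip, reloading with jit_load, can be sketched as follows. This is a minimal, self-contained example under a few assumptions: it writes to a temporary file instead of "linear.pt", and it assumes the object returned by `jit_load()` can be called like the traced function.

```r
library(torch)

# Same toy function as above: a fixed linear map plus a bias.
w <- torch_randn(10, 1)
b <- torch_randn(1)
fn <- function(x) torch_mm(x, w) + b

x <- torch_ones(2, 10)
tr_fn <- jit_trace(fn, x)

# Serialize to a temporary file and load it back in the same R session.
path <- tempfile(fileext = ".pt")
jit_save(tr_fn, path)
loaded_fn <- jit_load(path)

# The reloaded function should reproduce the traced function's output.
same <- torch_allclose(tr_fn(x), loaded_fn(x))
```

Since the weights were recorded as constants in the graph, the reloaded function does not need `w` or `b` to exist in the session that loads it.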
This is just the initial support for JIT in R, and we will continue developing
it. Specifically, in the next version of torch we plan to support tracing nn_modules
directly. Currently, you need to detach all parameters before
tracing them; see an example here. Direct module tracing will also let you take advantage of TorchScript to make your models
run faster!
Also note that tracing has some limitations, especially when your code has loops
or control flow statements that depend on tensor data. See ?jit_trace to
learn more.
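As a small illustration of the control-flow limitation, consider the sketch below. The function `branchy` is hypothetical and exists only for this example; the point is that tracing records the operations executed for the example input, so a data-dependent `if` is effectively frozen at trace time.

```r
library(torch)

# The branch taken depends on the values inside the tensor.
branchy <- function(x) {
  if (as.numeric(x$sum()) > 0) {
    x + 1
  } else {
    x - 1
  }
}

# Trace with a positive example input: only the `x + 1` branch is recorded.
traced <- jit_trace(branchy, torch_ones(2))

neg <- -torch_ones(2)
eager_out <- branchy(neg)   # plain R evaluates the `if` and takes `x - 1`
traced_out <- traced(neg)   # the traced graph replays the recorded `x + 1`
```

Here the eager and traced results disagree for the negative input, which is exactly the kind of silent mismatch to watch out for when tracing code with data-dependent branches.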
New print method for nn_modules
In this release we have also improved the nn_module printing methods in order
to make it easier to understand what’s inside.
For example, if you create an instance of an nn_linear module you will
see:
nn_linear(10, 1)
An `nn_module` containing 11 parameters.
── Parameters ──────────────────────────────────────────────────────────────────
● weight: Float [1:1, 1:10]
● bias: Float [1:1]
You immediately see the total number of parameters in the module as well as
their names and shapes.
This also works for custom modules (possibly including sub-modules). For example:
my_module <- nn_module(
  initialize = function() {
    self$linear <- nn_linear(10, 1)
    self$param <- nn_parameter(torch_randn(5, 1))
    self$buff <- nn_buffer(torch_randn(5))
  }
)
my_module()
An `nn_module` containing 16 parameters.
── Modules ─────────────────────────────────────────────────────────────────────
● linear: <nn_linear> #11 parameters
── Parameters ──────────────────────────────────────────────────────────────────
● param: Float [1:5, 1:1]
── Buffers ─────────────────────────────────────────────────────────────────────
● buff: Float [1:5]
We hope this makes it easier to understand nn_module objects.
We have also improved autocomplete support for nn_modules and we will now
show all sub-modules, parameters and buffers while you type.
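If you want to check the parameter totals reported by the print method yourself, a rough sketch is shown below. The helper `count_params` is hypothetical (it is not part of torch), and it assumes that `$parameters` returns all parameters of a module, including those of its sub-modules.

```r
library(torch)

# Hypothetical helper: sum the number of elements over all parameters.
count_params <- function(module) {
  sum(vapply(module$parameters, function(p) prod(p$shape), numeric(1)))
}

# 10 weights + 1 bias
n_linear <- count_params(nn_linear(10, 1))

my_module <- nn_module(
  initialize = function() {
    self$linear <- nn_linear(10, 1)
    self$param <- nn_parameter(torch_randn(5, 1))
    self$buff <- nn_buffer(torch_randn(5))
  }
)

# 11 from the linear sub-module + 5 from `param`; the buffer is not a parameter
n_custom <- count_params(my_module())
```

The buffer is listed in its own section of the printed output precisely because it does not count towards the parameter total.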
torchaudio
torchaudio is an extension for torch developed by Athos Damiani (@athospd), providing audio loading, transformations, common architectures for signal processing, pre-trained weights and access to commonly used datasets. It is an almost literal translation of PyTorch’s Torchaudio library to R.
torchaudio is not yet on CRAN, but you can already try the development version
available here.
You can also visit the pkgdown website for examples and reference documentation.
Other features and bug fixes
Thanks to community contributions we have found and fixed many bugs in torch.
We have also added new features; you can see the full list of changes in the NEWS.md file.
Thanks very much for reading this blog post, and feel free to reach out on GitHub for help or discussions!
The photo used in this post preview is by Oleg Illarionov on Unsplash