Migrating to PyTorch
There are three key things that you need to become accustomed to when using PyTorch: dataloading, model design, and tensor residency.
So I do a fair amount of deep learning work, and I recently migrated from Keras to PyTorch. It wasn't the smoothest transition, but there are a few things PyTorch does that I like. And a few that I don't.
When I started with PyTorch, I didn't have anybody around to point these things out to me, so I thought it might be nice to point some of them out to you in case you're in a similar position.
If you've worked with various deep learning frameworks for a bit and had to deal with more than one, you've discovered that they all have their quirks. TensorFlow had the whole tensor thing and the ability to defer model execution, which was new to lots of us. Keras is nice, as it abstracts a lot of the TensorFlow details and lets you focus on the model, but you end up spending a lot of time tuning dimensions until everything lines up. PyTorch keeps the tensor abstraction, which I'm used to by now, and is as general as Keras, but not as hung up on dimensionality, which is good and bad. And of course, all the frameworks use upper camel case. Which is important, of course.
There are three key things that you need to become accustomed to when using PyTorch: dataloading, model design, and tensor residency.
First, PyTorch uses an abstraction in which a DataLoader loads data from a Dataset. It's actually really handy: you can set the batch size, the number of worker processes used to load data in parallel, whether the data is shuffled, and a few other things. Implementing a Dataset to feed the loader is straightforward too. Essentially, you create an object that supports direct indexing (it implements __getitem__ and __len__), and the loader takes care of the rest. All you need to do is provide access to your data, and PyTorch takes it from there. The companion torchvision package also provides ready-made access to a range of typical datasets, including CIFAR and MNIST, so you can easily validate initial model designs.
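To make the Dataset and DataLoader pairing concrete, here is a minimal sketch. The SquaresDataset class and its toy data are hypothetical, invented purely for illustration; only Dataset, DataLoader, and their batch_size, shuffle, and num_workers arguments come from PyTorch itself.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: item i is the pair (i, i * i)."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        # The loader uses this to know how many items exist.
        return self.size

    def __getitem__(self, index):
        # Direct indexing: return one (input, target) pair as tensors.
        x = torch.tensor([float(index)])
        y = torch.tensor([float(index * index)])
        return x, y


dataset = SquaresDataset(10)

# num_workers=0 loads data in the main process; a larger value
# starts that many worker processes to load batches in parallel.
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

batches = list(loader)
first_inputs, first_targets = batches[0]
print(len(batches), first_inputs.shape)
```

With ten items and a batch size of four, the loader yields two full batches and one final batch of two, and each batch stacks the individual items along a new leading dimension.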
Second, PyTorch handles many model-building details for you. In Keras, you need to explicitly align the dimensions of your layers until things fit together appropriately. This is a fine way to approach model building, but PyTorch takes a different approach. Instead of full layer dimensions, PyTorch's convolutional layers use a channel concept. Think of a typical RGB color pixel: each color in the 3-tuple is a channel. You don't explicitly define the spatial dimensions; you just define how many channels go in and how many come out, and PyTorch works out the remaining dimensions at run time (if it is possible). The only time I've had issues with this is when applying a linear layer to the output of a convolutional layer, because the linear layer does need to know its input size up front.
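Here is a rough sketch of what channel-based design looks like, including the one spot where dimensions still have to be worked out by hand. The TinyConvNet class, the layer sizes, and the 32x32 input are all assumptions chosen for illustration, not anything prescribed by PyTorch.

```python
import torch
from torch import nn


class TinyConvNet(nn.Module):
    """Hypothetical model: one conv layer feeding one linear layer."""

    def __init__(self, num_classes=10):
        super().__init__()
        # Only channel counts are given; height and width are not.
        self.conv = nn.Conv2d(in_channels=3, out_channels=8,
                              kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        # The linear layer is the exception: it needs a flat input size.
        # Assuming 32x32 inputs, pooling halves them to 16x16, so the
        # flattened size is 8 channels * 16 * 16.
        self.fc = nn.Linear(8 * 16 * 16, num_classes)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv(x)))
        # Collapse (channels, height, width) into one feature dimension.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)


model = TinyConvNet()
images = torch.randn(4, 3, 32, 32)  # batch of 4 RGB images
logits = model(images)
print(logits.shape)
```

If the input images were a different size, the convolution and pooling layers would still run, but the hard-coded input size of the linear layer would no longer match and the forward pass would fail there.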
Finally, tensor residency. I generally execute production runs over complete datasets on GPUs, but develop my model architecture on a CPU. The CPU is certainly slower, but at that stage I'm mostly concerned with making sure all the data types are correct and that the model will execute correctly over the input data. Tensors are tied to a device as well as a type, however, so in order to hand tensor data from the GPU (where you're computing on it) to, say, scikit-learn's confusion_matrix() function, you need to move the data from GPU memory to primary memory. It's easy enough to do: you just call the .cpu() method on the tensor. It can be hard to find out where this needs to be done, though.
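As a small sketch of that hand-off, the snippet below picks a device, computes predictions there, and then brings the results back to primary memory as NumPy arrays. The logits and labels are made-up values for illustration, and the code falls back to the CPU when no GPU is available, in which case .cpu() simply returns the tensor unchanged.

```python
import torch

# Use the GPU when one is present, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Made-up model outputs and labels, created directly on the device.
logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2]], device=device)
labels = torch.tensor([0, 1, 1], device=device)

# This still runs wherever the tensors live.
predictions = logits.argmax(dim=1)

# NumPy-based libraries cannot read GPU memory, so copy to the CPU
# first. A tensor that tracks gradients would also need .detach().
y_pred = predictions.cpu().numpy()
y_true = labels.cpu().numpy()

# y_true and y_pred can now be passed to a CPU-side function such as
# scikit-learn's confusion_matrix.
print(y_true, y_pred)
```

Keeping the .cpu() calls at the boundary where data leaves PyTorch, rather than scattering them through the model code, makes it easier to see where the transfers happen.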
Overall, PyTorch provides a strong environment for deep learning work. But like all frameworks, it certainly has its quirks.