Migrating to PyTorch

There are three key things that you need to become accustomed to when using PyTorch: data loading, model design, and tensor residency.

By Christopher Lamb · Feb. 20, 19 · Analysis

So I do a fair amount of deep learning work, and I recently migrated from Keras to PyTorch. It wasn't the smoothest transition, but there are a few things PyTorch does that I like. And a few that I don't.

When I started with PyTorch, I didn't have anybody around to point these things out to me, so I thought it might be nice to point some of them out to you, in case you're in a similar position.

If you've worked with various deep learning frameworks for a bit and had to deal with more than one, you've discovered that they all have their quirks. TensorFlow had the whole tensor thing and the ability to defer model execution, which was new to lots of us. Keras is nice, as it abstracts a lot of the TensorFlow details and lets you focus on the model, but you end up spending a lot of time tuning dimensions until everything lines up. PyTorch keeps the tensor abstraction, which I'm used to by now, and is as general as Keras, but not as hung up on dimensionality, which is good and bad. And of course, all the frameworks use upper camel case. Which is important, of course.

There are three key things that you need to become accustomed to when using PyTorch: data loading, model design, and tensor residency.

First, PyTorch uses this abstraction where DataLoaders load data from Datasets. It's actually really handy — you can set the batch size, the number of worker processes, whether the data is shuffled, and a few other things. Implementing a Dataset to feed the loader is straightforward too — essentially, you create an object that supports direct indexing and reports its length, and the loader takes care of the rest. All you need to do is provide access to your data, and PyTorch takes it from there. PyTorch also provides access to a range of typical datasets, including CIFAR and MNIST, so you can easily validate initial model designs.
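Here's a minimal sketch of that Dataset/DataLoader pattern. The class name and the NumPy arrays fed into it are made up for illustration; the only real requirements are `__len__` and `__getitem__`.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """A hypothetical dataset wrapping a feature matrix and label vector."""
    def __init__(self, features: np.ndarray, labels: np.ndarray):
        self.features = torch.from_numpy(features).float()
        self.labels = torch.from_numpy(labels).long()

    def __len__(self):
        # The loader needs to know how many samples exist.
        return len(self.labels)

    def __getitem__(self, idx):
        # Direct indexing: return one (sample, label) pair.
        return self.features[idx], self.labels[idx]

# Batching, shuffling, and worker processes are configured on the loader,
# not on the dataset itself.
dataset = MyDataset(np.random.rand(1000, 20), np.random.randint(0, 2, 1000))
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)

for batch_features, batch_labels in loader:
    pass  # a training step would go here
```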

Second, PyTorch handles model building details for you. In Keras, you need to explicitly align the dimensions of your layers until things fit together appropriately. This is a fine way to approach model building, but PyTorch takes a completely different approach. Instead of using layer dimensions, PyTorch uses a channel concept. Think of a typical RGB-valued color pixel — each color in the 3-tuple is a channel. You don't explicitly define dimensionality; you just define channel handling, and PyTorch takes care of dimension alignment for you (where possible). The only time I've had issues with this is when executing a linear transform over a convolutional layer.
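A small sketch of what that looks like in practice, assuming single-channel 28x28 inputs (MNIST-sized) and illustrative layer sizes. Note that the convolutions are declared purely in terms of channels, while the linear layer at the end is exactly the spot where you still have to work out a flattened size by hand:

```python
import torch
import torch.nn as nn

class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Convolutions are declared as channels in/out, not spatial sizes.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        # The linear layer is the one place the flattened size matters:
        # 32 channels * 7 * 7 after two 2x2 poolings of a 28x28 input.
        self.fc = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)  # flatten before the linear transform
        return self.fc(x)

model = SmallNet()
out = model(torch.randn(8, 1, 28, 28))  # a batch of 8 single-channel images
print(out.shape)  # torch.Size([8, 10])
```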

Finally, tensor residency. I generally execute production runs over complete datasets on GPUs, but develop my model architecture on CPU. CPU is certainly slower, but on CPU I'm usually more concerned with making sure all the data types are correct and that the model will execute correctly over the input data. Tensors are typed, however, and in order to hand tensor data from the GPU (where you're executing it) to, say, scikit-learn's confusion_matrix() function, you need to move the data from the GPU to primary memory. It's easy enough to do — you just call the .cpu() method on the tensor — but it can be hard to find out where this needs to be done.
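A short sketch of where that .cpu() call typically lands, assuming a trained `model` and a `DataLoader` named `loader` already exist (both are placeholders here):

```python
import torch
from sklearn.metrics import confusion_matrix

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

all_preds, all_labels = [], []
with torch.no_grad():
    for features, labels in loader:
        outputs = model(features.to(device))
        preds = outputs.argmax(dim=1)
        # confusion_matrix expects NumPy arrays, so GPU tensors have to come
        # back to primary memory first: .cpu() does that move.
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())

print(confusion_matrix(all_labels, all_preds))
```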

Overall, PyTorch provides a strong environment for deep learning work. But like all frameworks, it certainly has its quirks.
