dmlc/dgl

Mac MPS support

Open · #4725 opened on Oct 17, 2022

Labels: feature request, help wanted

Description

🐛 Bug

I installed PyTorch with GPU (MPS) support following this tutorial, and I can confirm that the GPU is working:

import torch
import math

dtype = torch.float
device = torch.device("mps") # --> Apple Metal framework integration for GPU

# Create random input and output data
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Randomly initialize weights
a = torch.randn((), device=device, dtype=dtype)
b = torch.randn((), device=device, dtype=dtype)
c = torch.randn((), device=device, dtype=dtype)
d = torch.randn((), device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(2000):
    # Forward pass: compute predicted y
    y_pred = a + b * x + c * x**2 + d * x**3

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    if t % 100 == 99:
        print(t, loss)

    # Backprop to compute gradients of a, b, c, d with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_a = grad_y_pred.sum()
    grad_b = (grad_y_pred * x).sum()
    grad_c = (grad_y_pred * x**2).sum()
    grad_d = (grad_y_pred * x**3).sum()

    # Update weights using gradient descent
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b
    c -= learning_rate * grad_c
    d -= learning_rate * grad_d


print(f"Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3")

Output:

Result: y = 0.050219081342220306 + 0.8358809351921082 x + -0.008663627319037914 x^2 + -0.0903632640838623 x^3
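A lighter check than running a training loop is to ask PyTorch directly whether the MPS backend is usable. The sketch below assumes PyTorch 1.12 or later, where torch.backends.mps.is_available() exists; the helper name preferred_torch_device is hypothetical, and the import guard is there only so the snippet also runs where PyTorch is not installed.

```python
# Sketch: pick the best available PyTorch device, preferring MPS (Apple Metal).
# Assumes PyTorch >= 1.12, which provides torch.backends.mps.
try:
    import torch
except ImportError:  # PyTorch not installed: only the CPU fallback applies
    torch = None


def preferred_torch_device() -> str:
    """Hypothetical helper: "mps" if usable, else "cuda" if usable, else "cpu"."""
    if torch is None:
        return "cpu"
    # Older PyTorch builds have no torch.backends.mps attribute at all.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


print(preferred_torch_device())
```

On a Mac where the loop above runs on device "mps", this helper should also report "mps".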

Then I try to move a DGL graph to the same device:

graph = ....
gpu = torch.device("mps")
graph.to(gpu)  # note: .to() returns a new graph; it does not modify graph in place

Output:

Traceback (most recent call last):
  File "/Users/diogosilva/code/ubiwhere/models-hub/graph.py", line 60, in <module>
    graph.to(gpu)
  File "/Users/diogosilva/.pyenv/versions/models-hub/lib/python3.9/site-packages/dgl/heterograph.py", line 5448, in to
    ret._graph = self._graph.copy_to(utils.to_dgl_context(device))
  File "/Users/diogosilva/.pyenv/versions/models-hub/lib/python3.9/site-packages/dgl/utils/internal.py", line 534, in to_dgl_context
    device_type = nd.DGLContext.STR2MASK[F.device_type(ctx)]
KeyError: 'mps'
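The last frame shows where this fails: to_dgl_context converts the PyTorch device type string into a DGL device code by indexing the dictionary nd.DGLContext.STR2MASK, and that dictionary has no 'mps' key. The pure-Python sketch below only mimics that lookup; STR2MASK_SUBSET and its values are hypothetical stand-ins for illustration, not DGL's actual table.

```python
# Hypothetical stand-in for DGL's device-type table (values are illustrative only).
STR2MASK_SUBSET = {"cpu": 1, "cuda": 2}


def to_device_code(device_type: str) -> int:
    """Mimic the failing lookup: plain dict indexing raises KeyError for unknown types."""
    return STR2MASK_SUBSET[device_type]


print(to_device_code("cuda"))  # 2
try:
    to_device_code("mps")
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'mps'
```

Supporting MPS would therefore need more than a new dictionary entry: DGL's C++ runtime would also have to be able to allocate and operate on memory for that device.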

Expected behavior

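It should behave like any other GPU device: graph.to(torch.device("mps")) should return a copy of the graph on the Apple GPU, just as it does for a CUDA device.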
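Until DGL recognizes "mps", one possible workaround is to keep the graph on a device DGL does accept and use MPS only for plain PyTorch tensor work. The sketch below assumes, as the traceback suggests for DGL 0.9.1, that only "cpu" and "cuda" device types are accepted; DGL_ACCEPTED and dgl_device_for are hypothetical names.

```python
# Assumption: the DGL build in use accepts only these PyTorch device types.
DGL_ACCEPTED = frozenset({"cpu", "cuda"})


def dgl_device_for(requested: str) -> str:
    """Return `requested` if DGL is assumed to accept it, otherwise fall back to "cpu"."""
    device_type = requested.split(":", 1)[0]  # "cuda:0" -> "cuda"
    return requested if device_type in DGL_ACCEPTED else "cpu"


print(dgl_device_for("mps"))     # cpu
print(dgl_device_for("cuda:0"))  # cuda:0
```

With PyTorch and DGL installed, this would be used as graph = graph.to(torch.device(dgl_device_for("mps"))), reassigning the result because .to() returns a new graph. Node and edge features used in DGL message passing must live on the same device as the graph, so in this setup that computation runs on the CPU.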

Environment

  • DGL Version: 0.9.1
  • Backend Library & Version: PyTorch 1.14.0.dev20221017 (nightly)
  • OS (e.g., Linux): macOS
  • How you installed DGL (conda, pip, source): pip
  • Python version: 3.9.1
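The version fields above can also be collected programmatically when filing a report. This is a minimal sketch using only the standard library; environment_report is a hypothetical helper, and its key names are chosen for illustration rather than copied from the issue template.

```python
import importlib
import platform


def environment_report() -> dict:
    """Hypothetical helper: gather versions relevant to a DGL bug report."""
    report = {
        "Python version": platform.python_version(),
        "OS": platform.platform(),
    }
    for name in ("dgl", "torch"):
        try:
            module = importlib.import_module(name)
            report[f"{name} version"] = str(getattr(module, "__version__", "unknown"))
        except Exception:  # any import failure is reported rather than raised
            report[f"{name} version"] = "not installed or failed to import"
    return report


for key, value in environment_report().items():
    print(f"{key}: {value}")
```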

Additional context
