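This issue can be resolved with the jax.default_device context manager introduced in #9118. Inside the with block, JAX operations whose inputs are not already committed to another device run on the given device (here, the first CPU device):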

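import jax

# jax.devices('cpu') lists the CPU devices; take the first one.
device_cpu = jax.devices('cpu')[0]
with jax.default_device(device_cpu):
    # The sample is computed on the CPU even if a GPU/TPU is the default backend.
    a = jax.random.poisson(jax.random.PRNGKey(0), 3, shape=(1000,))
    print(a.device())  # TFRT_CPU_0 (on newer JAX versions, use a.devices() instead)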
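A minimal sketch of the same idea wrapped in a reusable function, assuming JAX 0.4 or later (for jax.Array.devices()). The helper name sample_on is hypothetical; it only shows that the default-device override is scoped to the with block and that the result can be checked afterwards.

```python
import jax

# Hypothetical helper: draw n Poisson samples with the default device
# temporarily set to `device`. The previous default is restored on exit.
def sample_on(device, key, lam, n):
    with jax.default_device(device):
        return jax.random.poisson(key, lam, shape=(n,))

cpu = jax.devices("cpu")[0]
samples = sample_on(cpu, jax.random.PRNGKey(0), 3.0, 1000)

# jax.Array.devices() returns the set of devices that hold the array.
print(samples.devices())
```

Scoping the override to a function like this keeps the rest of the program on its usual default device.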

Replies: 3 comments, 2 replies (participants: @YouJiacheng, @ayaka14732)
Answer selected by ayaka14732
Category: Q&A
2 participants