How to force random numbers to be generated on CPU? #9691
-
I would like to generate random numbers with jax.random.poisson:

>>> a = jax.random.poisson(jax.random.PRNGKey(0), 3, shape=(1000,))
>>> a.device()
TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)

The results are generated on the default device (here, the TPU). However, I would like them to be generated on the CPU. One way I can think of is to wrap the call in jax.jit with backend='cpu':

>>> a = jax.jit(lambda: jax.random.poisson(jax.random.PRNGKey(0), 3, shape=(1000,)), backend='cpu')()
>>> a.device()
CpuDevice(id=0)

However, this method is slower because of the jit compilation. What is the best way of doing this?
Replies: 3 comments 2 replies
-
What about jitting the function once and reusing it?

generate = jax.jit(lambda key: jax.random.poisson(key, 3, shape=(1000,)), backend='cpu')
key = jax.random.PRNGKey(0)
for _ in range(100):
    key, subkey = jax.random.split(key)
    generate(subkey)

This only needs one compilation.
-
I am expecting something like this:

>>> key = jax.random.PRNGKey(0, backend='cpu')
>>> key, subkey = jax.random.split(key)
>>> a = jax.random.poisson(subkey, 3, shape=(1000,), backend='cpu')
-
This issue can be resolved by using the default device context manager introduced in #9118:

import jax
device_cpu = jax.devices('cpu')[0]
with jax.default_device(device_cpu):
    a = jax.random.poisson(jax.random.PRNGKey(0), 3, shape=(1000,))
    print(a.device())  # TFRT_CPU_0
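
To connect this with the key-splitting loop suggested earlier in the thread, here is a minimal sketch of drawing several batches inside the context manager. The helper name sample_poisson_on_cpu and its parameters are hypothetical, made up for illustration. The placement check uses the array's devices() method, which I believe is what recent JAX releases provide; on older releases the a.device() call shown above plays the same role.

```python
import jax


def sample_poisson_on_cpu(seed, lam, shape, num_batches):
    """Hypothetical helper: draw Poisson batches with the CPU as default device."""
    cpu = jax.devices("cpu")[0]
    batches = []
    # Everything created inside this block is placed on the CPU by default,
    # including the PRNG keys, so no jit(backend='cpu') wrapper is needed.
    with jax.default_device(cpu):
        key = jax.random.PRNGKey(seed)
        for _ in range(num_batches):
            # Split so that every batch uses a fresh subkey.
            key, subkey = jax.random.split(key)
            batches.append(jax.random.poisson(subkey, lam, shape=shape))
    return batches


batches = sample_poisson_on_cpu(seed=0, lam=3, shape=(1000,), num_batches=5)
# Collect the platform name of every device that holds a result.
platforms = {d.platform for b in batches for d in b.devices()}
print(platforms)
```

Creating the key inside the with block matters here: a key created outside it would live on the default accelerator and could pull the sampling back there.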