38. Backend
import abc
import six

@six.add_metaclass(abc.ABCMeta)
class Backend(object):
  """Abstract base class for XLA backends."""

  def __init__(self, platform):
    """Creates a new Backend.

    Args:
      platform: A string naming the platform; for example 'gpu'.
    """
    self.platform = platform
https://github.com/tensorflow/tensorflow/blob/4406ec35b6056ea7a1314979292407f1b1dd6409/tensorflow/compiler/xla/python/xla_client.py
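To illustrate what the abstract base class above provides, here is a minimal sketch in Python 3 syntax, using `metaclass=abc.ABCMeta` in place of `six.add_metaclass`. The abstract method `device_count` and the subclass `FakeCpuBackend` are hypothetical names chosen for this example; they are not taken from xla_client.py.

```python
import abc


class Backend(metaclass=abc.ABCMeta):
    """Abstract base class with the same shape as the Backend class above."""

    def __init__(self, platform):
        self.platform = platform

    @abc.abstractmethod
    def device_count(self):
        """Hypothetical abstract method: number of devices on this backend."""


class FakeCpuBackend(Backend):
    """Hypothetical concrete backend, used only for illustration."""

    def __init__(self):
        super().__init__("cpu")

    def device_count(self):
        return 1


backend = FakeCpuBackend()
print(backend.platform, backend.device_count())

# The abstract class itself cannot be instantiated: Python raises TypeError
# because device_count has no implementation.
try:
    Backend("gpu")
except TypeError:
    abstract_instantiation_failed = True
else:
    abstract_instantiation_failed = False
```

Each platform (CPU, GPU, TPU) can then supply its own subclass, while callers depend only on the shared interface.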
48. JAX: Autograd and XLA
JAX is Autograd and XLA, brought together for high-performance
machine learning research.
What’s new is that JAX uses XLA to compile and run your
NumPy programs on GPUs and TPUs.
● Cloud TPU support
● Multi-GPU and multi-TPU support
49. TpuBackend was used in JAX
from . import tpu_client

def _get_tpu_driver_backend(platform):
  del platform
  backend_target = FLAGS.jax_backend_target
  if backend_target is None:
    raise ValueError('When using TPU Driver as the backend, you must specify '
                     '--jax_backend_target=<hostname>:8470.')
  return tpu_client.TpuBackend.create(worker=backend_target)
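The function above fails fast when no backend target is configured. The same pattern can be sketched in a self-contained form, under these assumptions: a plain `backend_target` argument stands in for `FLAGS.jax_backend_target`, the hypothetical class `FakeTpuBackend` stands in for `tpu_client.TpuBackend`, and the platform string and worker address are made up for illustration. None of this is the real JAX code.

```python
class FakeTpuBackend:
    """Hypothetical stand-in for tpu_client.TpuBackend."""

    def __init__(self, worker):
        self.worker = worker

    @classmethod
    def create(cls, worker):
        return cls(worker)


def get_tpu_driver_backend(platform, backend_target=None):
    """Return a backend bound to backend_target, or raise if it is unset."""
    del platform  # Unused, as in the function above.
    if backend_target is None:
        raise ValueError(
            "When using TPU Driver as the backend, you must specify "
            "--jax_backend_target=<hostname>:8470.")
    return FakeTpuBackend.create(worker=backend_target)


# Illustrative platform name and address; not real endpoints.
backend = get_tpu_driver_backend("tpu_driver", "grpc://10.0.0.2:8470")
print(backend.worker)

# Without a target, the factory raises before any connection is attempted.
try:
    get_tpu_driver_backend("tpu_driver")
except ValueError as err:
    error_message = str(err)
```

Raising in the factory, rather than later at first use, means a missing target surfaces as a clear configuration error.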
50. Add TPU Driver to jaxlib #1673
Add TPU Driver: a low-level TPU API as a JAX backend for Cloud TPU.
https://github.com/tensorflow/tensorflow/commit/4406ec35b6056ea7a1314979292407f1b1dd6409