question on trax

Dennis Lee Bieber wlfraed at ix.netcom.com
Wed Aug 18 12:42:59 EDT 2021


On Tue, 17 Aug 2021 17:50:59 +0200, joseph pareti <joepareti54 at gmail.com>
declaimed the following:

>In the following code, where does tl.Fn come from? I see it nowhere in the
>documentation, i.e., I was looking for trax.layers.Fn:

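	The "layers" package (its __init__.py) imports a whole slew of submodules using
		from xxx import *
in order to put all the submodule names at the same level. So tl.Fn is the
Fn defined in trax/layers/base.py, re-exported as trax.layers.Fn.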
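To see that mechanism in isolation, here is a small sketch using only the standard library. The module names fakelayers and fakelayers_base are hypothetical stand-ins for trax.layers and trax/layers/base.py; they are not trax's actual files, and the fake Fn body is a placeholder.

```python
import sys
import types

# Hypothetical stand-in for trax/layers/base.py: a submodule that defines Fn.
base = types.ModuleType("fakelayers_base")
exec(
    "def Fn(name, f, n_out=1):\n"
    "    return (name, f, n_out)  # placeholder body, not trax's code\n",
    base.__dict__,
)
sys.modules["fakelayers_base"] = base

# Hypothetical stand-in for trax/layers/__init__.py, which star-imports
# its submodules so their public names appear directly on the package.
layers = types.ModuleType("fakelayers")
exec("from fakelayers_base import *", layers.__dict__)
sys.modules["fakelayers"] = layers

import fakelayers as tl  # mirrors the usual `from trax import layers as tl`

# The re-exported name is the very same object as the submodule's Fn...
print(tl.Fn is base.Fn)   # True
# ...and __module__ still records the submodule where it was defined.
print(tl.Fn.__module__)   # fakelayers_base
```

With trax itself installed, tl.Fn.__module__ should likewise report trax.layers.base, and inspect.getsourcefile(tl.Fn) should point at that file. That is a quick way to find where a star-imported name actually lives when the docs only list it under the package.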

https://github.com/google/trax/blob/master/trax/layers/base.py
From line 748 on...

def Fn(name, f, n_out=1):  # pylint: disable=invalid-name
  """Returns a layer with no weights that applies the function `f`.

  `f` can take and return any number of arguments, and takes only positional
  arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
  The following, for example, would create a layer that takes two inputs and
  returns two outputs -- element-wise sums and maxima:

      `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`

  The layer's number of inputs (`n_in`) is automatically set to number of
  positional arguments in `f`, but you must explicitly set the number of
  outputs (`n_out`) whenever it's not the default value 1.

  Args:
    name: Class-like name for the resulting layer; for use in debugging.
    f: Pure function from input tensors to output tensors, where each input
        tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
        Output tensors must be packaged as specified in the `Layer` class
        docstring.
    n_out: Number of outputs promised by the layer; default value 1.

  Returns:
    Layer executing the function `f`.
  """
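To illustrate what the docstring means by n_in being inferred from the positional arguments of f while n_out must be stated, here is a minimal pure-Python sketch. MiniFn and count_positional_args are hypothetical names; this is not how trax implements Fn (the real layer works on JAX arrays inside trax's Layer machinery), only a toy model of the documented contract, with builtin max standing in for jnp.maximum on scalars.

```python
import inspect

_POSITIONAL = (inspect.Parameter.POSITIONAL_ONLY,
               inspect.Parameter.POSITIONAL_OR_KEYWORD)


def count_positional_args(f):
    """Return how many positional parameters f has, rejecting defaults,
    *args, **kwargs and keyword-only parameters, per the Fn docstring."""
    params = list(inspect.signature(f).parameters.values())
    for p in params:
        if p.kind not in _POSITIONAL or p.default is not inspect.Parameter.empty:
            raise ValueError(f"parameter {p.name!r} is not a plain positional arg")
    return len(params)


class MiniFn:
    """Hypothetical weightless layer; a toy model, NOT trax's implementation."""

    def __init__(self, name, f, n_out=1):
        self.name = name
        self.f = f
        self.n_in = count_positional_args(f)  # inferred from f's signature
        self.n_out = n_out                    # must be given explicitly if != 1

    def __call__(self, *inputs):
        if len(inputs) != self.n_in:
            raise TypeError(f"{self.name} expects {self.n_in} inputs, got {len(inputs)}")
        outputs = self.f(*inputs)
        # One output is returned bare; several come back as a tuple.
        got = len(outputs) if isinstance(outputs, tuple) else 1
        if got != self.n_out:
            raise ValueError(f"{self.name} promised {self.n_out} outputs, got {got}")
        return outputs


# Scalar version of the docstring's SumAndMax example.
sum_and_max = MiniFn("SumAndMax", lambda x0, x1: (x0 + x1, max(x0, x1)), n_out=2)
print(sum_and_max.n_in)    # 2
print(sum_and_max(3, 5))   # (8, 5)
```

The asymmetry is easy to see here: the number of inputs can be read off the signature without running f, but the number of values a lambda returns cannot, which is presumably why the docstring asks you to pass n_out whenever it is not 1.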


-- 
	Wulfraed                 Dennis Lee Bieber         AF6VN
	wlfraed at ix.netcom.com    http://wlfraed.microdiversity.freeddns.org/



More information about the Python-list mailing list