JAX `vmap` 对于多个参数的意外行为

问题内容
,我发现 jax 中的 vmap 在应用于多个参数时不会按预期运行。例如,考虑下面的函数:,对于 x = jnp.arange(7), y = jnp.arange(5), z = jnp.arange(3),该函数的输出形状为 (7, 5, 3)。但是,对于以下 vmap 版本:,它输出此错误:,有人可以解释一下这个错误背后的原因吗?,vmap 的语义是它对一个或多个数组执行单个批处理操作。当您指定 in_axes=(none, 0, 0) 时,含义是“同时沿 yz 的前导维度映射”:您看到的错误告诉您 yy 的前导维度具有不同的大小,因此它们不兼容批处理。,您的函数 f1 本质上使用广播来编码三个批处理操作,因此要使用 vmap 复制该逻辑,您将需要 vmap 的三个应用程序。您可以这样表达:,
返回顶部
跳到底部

Copyright 2011-2024 南京追名网络科技有限公司 苏ICP备2023031119号-6 乌徒帮 All Rights Reserved Powered by Z-BlogPHP Theme By open开发

请先 登录 再评论,若不是会员请先 注册