nn.Embedding()个人记录

 维度

import torch.nn as nnembedding = nn.Embedding(num_embeddings = 10, embedding_dim = 256)

nn.Embedding()随机产生一个权重矩阵weight,维度为(num_embeddings, embedding_dim) 

输入维度(batch_size, Seq_len)

输出维度(batch_size,Seq_len,embedding_dim)

举例

 

参考&转载:

pytorch复习笔记--nn.Embedding()的用法-CSDN博客