
Detailed explanation of PyTorch nn.Embedding

After the text sequence is tokenized and mapped, the string sequence is converted into a sequence of numbers. These token ids can be fed directly into the model, but it should be understood that the model cannot obtain rich information from a bare number. By analogy with human cognition: our understanding of a character or a word does not rely on the symbol alone, but on the meaning behind it.
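For example, a toy vocabulary is enough to see the string-to-token-id mapping. This is a purely illustrative sketch, not a real tokenizer, and the vocabulary here is made up:

# Toy illustration only: map a text sequence to token ids with a hand-made vocabulary
vocab = {"i": 0, "like": 1, "cats": 2, "and": 3, "dogs": 4}  # hypothetical vocabulary
tokens = "i like cats and dogs".split()                      # stand-in for a real tokenizer
token_ids = [vocab[t] for t in tokens]
print(token_ids)  # [0, 1, 2, 3, 4]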

Embedding layer

nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None)

A simple lookup table that stores embeddings of a fixed dictionary and size.

In other words, it is a lookup table that stores one embedding vector for each entry of a fixed-size vocabulary.

Parameters

  • num_embeddings (int): The size of the embedding dictionary, i.e. the vocabulary size (vocab size).
  • embedding_dim (int): The dimension of each embedding vector.
  • padding_idx (int, optional): The index used for padding. The embedding vector at this index is not updated during training, i.e. its gradient does not participate in backpropagation, and it is usually used as the "padding" marker. For a newly constructed Embedding module, the embedding vector at this index defaults to all zeros, but it can be changed to another value (a brief sketch follows this list).
  • max_norm (float, optional): If set, any embedding vector whose norm exceeds this value is renormalized so that its norm is at most max_norm.
  • norm_type (float, optional): The p of the p-norm used for the max_norm computation; defaults to 2, i.e. the 2-norm.
  • scale_grad_by_freq (bool, optional): If True, gradients are scaled by the inverse of the word frequency in the mini-batch, which damps the updates of high-frequency words. Defaults to False.
  • sparse (bool, optional): If True, the gradient with respect to the weight matrix is a sparse tensor, which helps save memory for very large vocabularies.
  • weight (Tensor): The learnable weight of the module, with shape (num_embeddings, embedding_dim); the initial values are sampled from the normal distribution N(0, 1).
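As a minimal sketch of the padding_idx and max_norm behaviour described above (the exact values depend on the random initialization):

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=5, embedding_dim=3, padding_idx=0)
print(emb.weight[0])                      # all zeros for a newly built module

emb_norm = nn.Embedding(5, 3, max_norm=1.0)
out = emb_norm(torch.tensor([1, 2]))
print(out.norm(dim=1))                    # each looked-up vector has 2-norm <= 1.0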

Methods

from_pretrained(embeddings, freeze=True, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)

Create Embedding instance from given 2-dimensional FloatTensor.

In other words, it creates an Embedding instance whose weights are initialized from a given 2-dimensional FloatTensor (a usage sketch follows the parameter list below).

Parameters

  • embeddings (Tensor): A FloatTensor containing the embedding weights. The first dimension is num_embeddings (the vocabulary size), the second dimension is embedding_dim (the embedding vector dimension).
  • freeze (bool, optional): If True, the embedding tensor is kept fixed during training, which is equivalent to setting embedding.weight.requires_grad = False. Defaults to True.
  • The remaining parameters are as defined above.
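A minimal usage sketch of from_pretrained, assuming a hypothetical 3 x 2 weight matrix chosen just for illustration:

import torch
import torch.nn as nn

# Hypothetical pre-trained weights: 3 tokens x 2 dimensions
weights = torch.FloatTensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
embedding = nn.Embedding.from_pretrained(weights, freeze=True)
print(embedding(torch.tensor([1])))    # tensor([[0.3000, 0.4000]])
print(embedding.weight.requires_grad)  # False, because freeze=True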

Key points to be continued... (expected to be posted before 11.6)

Q&A

Q1: For a neural network, what is the "symbol", and what is the "meaning behind it"?

The answer: the Token ID and the Embedding.

So, what is Embedding?

We can understand it through nn.Embedding. Skip the tedious introduction for now and run the code to get an intuitive feel for it:

import torch
import torch.nn as nn
# Set a random seed to make the results reproducible
torch.manual_seed(42)
# Define embedding layer parameters
num_embeddings = 5  # Assume there are 5 tokens in the vocabulary
embedding_dim = 3   # Each token corresponds to a 3-dimensional embedding vector
# Initialize the embedding layer
embedding = nn.Embedding(num_embeddings, embedding_dim)
# Define the integer indices
input_indices = torch.tensor([0, 2, 4])
# Look up the embedding vectors
output = embedding(input_indices)
# Print the results
print("Weight matrix:")
print(embedding.weight)
print("\nEmbedding output:")
print(output)

Output:

Weight matrix:
tensor([[ 0.3367,  0.1288,  0.2345],
        [ 0.2303, -1.1229, -0.1863],
        [ 2.2082, -0.6380,  0.4617],
        [ 0.2674,  0.5349,  0.8094],
        [ 1.1103, -1.6898, -0.9890]])

Embedding output:
tensor([[ 0.3367,  0.1288,  0.2345],
        [ 2.2082, -0.6380,  0.4617],
        [ 1.1103, -1.6898, -0.9890]], grad_fn=<EmbeddingBackward0>)

Here, input_indices = [0, 2, 4] selects rows 0, 2, and 4 from the weight matrix as the corresponding embedding representations. Yes, getting an Embedding is that simple.
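To confirm that the lookup really is just row selection, here is a self-contained check that re-creates the same layer with the same seed:

import torch
import torch.nn as nn

torch.manual_seed(42)
embedding = nn.Embedding(5, 3)
input_indices = torch.tensor([0, 2, 4])
# The layer's output equals direct row indexing of the weight matrix
print(torch.equal(embedding(input_indices), embedding.weight[input_indices]))  # True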

Next, build a minimal Embedding class to understand it (a simplified sketch that reuses torch and nn from the code above):

class Embedding():
    def __init__(self, num_embeddings, embedding_dim):
        # Random N(0, 1) initialization, mirroring nn.Embedding
        self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
    def forward(self, indices):
        return self.weight[indices]  # That's right, just return the corresponding rows

It can be seen that the essence of the Embedding class is a lookup table. In the example above, the table stores 5 (num_embeddings) embedding vectors, each with 3 dimensions (embedding_dim). When input_indices is provided, the lookup table returns the corresponding embedding vectors (rows of the weight matrix).
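A quick usage check of the minimal class sketched above (it relies on the torch.randn-based initialization assumed in that sketch and on the earlier imports):

toy = Embedding(num_embeddings=5, embedding_dim=3)   # the minimal class defined above
out = toy.forward(torch.tensor([0, 2, 4]))
print(out.shape)  # torch.Size([3, 3]): one 3-dimensional vector per index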

Q2: What is the initial weight matrix? What determines the final embedding vectors?

The initial weight matrix is generally randomly initialized; the weights are then updated during training so that they can effectively express the meaning behind each token.
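A minimal sketch of such an update, using a meaningless placeholder loss purely to trigger one gradient step:

import torch
import torch.nn as nn

torch.manual_seed(0)
embedding = nn.Embedding(5, 3)
optimizer = torch.optim.SGD(embedding.parameters(), lr=0.1)

before = embedding.weight[2].detach().clone()
loss = embedding(torch.tensor([2])).sum()  # placeholder "loss", not a real objective
loss.backward()
optimizer.step()
print(torch.equal(before, embedding.weight[2].detach()))  # False: row 2 was updated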

Q3: What is semantics?

Let's use a simple example to understand "semantic" relationships: the representations of "cat" and "dog" should be close in vector space because both are pets; the vector difference between "man" and "woman" may capture the difference in gender. In addition, words from different languages, such as the Chinese word for "man" and the English word "man", will also be close if they live in the same embedding space, reflecting semantic similarity across languages. Likewise, the Chinese-English difference for "woman" and the Chinese-English difference for "man" may be very similar.
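To make "close in vector space" concrete, here is a sketch with hand-crafted, entirely hypothetical 2-dimensional vectors; real embeddings are learned, not assigned by hand:

import torch
import torch.nn.functional as F

# Hypothetical 2-d vectors, only to illustrate closeness in vector space
cat = torch.tensor([0.9, 0.1])
dog = torch.tensor([0.85, 0.15])
car = torch.tensor([0.1, 0.9])

print(F.cosine_similarity(cat, dog, dim=0))  # close to 1: semantically similar
print(F.cosine_similarity(cat, car, dim=0))  # much smaller: semantically distant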

The Embedding discussed in this article, which appears alongside the Token id, is the "narrow sense" of the term; in natural language processing (NLP) this concept has a more specific name: Word Embedding.

This is the end of this detailed explanation of PyTorch nn.Embedding. For more related content on PyTorch nn.Embedding, please search for my previous articles or continue browsing the related articles below. I hope everyone will keep supporting me!