Shong

AMP Mixed Precision Training

What is AMP

By default, most deep learning frameworks use single precision (32-bit floating point) for training.

In 2017, NVIDIA researchers showed that training networks with a mix of single precision and half precision (16-bit floating point) reaches nearly the same accuracy as pure single-precision training, using the same hyperparameters.

Half precision (FP16): 16 bits, consisting of 1 sign bit, 5 exponent bits and 10 fraction bits

Single precision (FP32): 32 bits, consisting of 1 sign bit, 8 exponent bits and 23 fraction bits
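A quick way to see these layouts is to pack a number with Python's `struct` module (format `e` is IEEE half precision, `f` is single precision) and split the bits into the three fields. This is a minimal sketch using only the standard library; the helper `fields` and the `FORMATS` table are hypothetical names chosen for this illustration.

```python
import struct

# (struct format, exponent bits, fraction bits) for each IEEE 754 format
FORMATS = {"fp16": ("<e", 5, 10), "fp32": ("<f", 8, 23)}

def fields(x, fmt):
    """Return the (sign, exponent, fraction) bit fields of x encoded in fmt."""
    code, exp_bits, frac_bits = FORMATS[fmt]
    bits = int.from_bytes(struct.pack(code, x), "little")  # raw bit pattern
    fraction = bits & ((1 << frac_bits) - 1)               # lowest frac_bits bits
    exponent = (bits >> frac_bits) & ((1 << exp_bits) - 1)
    sign = bits >> (frac_bits + exp_bits)                  # the top bit
    return sign, exponent, fraction

print(fields(1.0, "fp16"))   # (0, 15, 0): stored exponent 15 encodes 2**0
print(fields(1.0, "fp32"))   # (0, 127, 0): stored exponent 127 encodes 2**0
print(fields(-2.5, "fp16"))  # (1, 16, 256): -2.5 = -1.01b * 2**1
```

The stored exponent is offset by a bias (15 for FP16, 127 for FP32), which the Data Representation section explains.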

PyTorch provides the following 10 basic CPU tensor types; the default is torch.FloatTensor (32-bit floating point).

torch.FloatTensor: 32-bit floating point
torch.DoubleTensor: 64-bit floating point
torch.HalfTensor: 16-bit floating point
torch.BFloat16Tensor: 16-bit floating point (bfloat16)
torch.ByteTensor: 8-bit integer (unsigned)
torch.CharTensor: 8-bit integer (signed)
torch.ShortTensor: 16-bit integer (signed)
torch.IntTensor: 32-bit integer (signed)
torch.LongTensor: 64-bit integer (signed)
torch.BoolTensor: Boolean

Automatic mixed precision (AMP) has two key ideas:

  • Automatic: the framework changes the dtype of tensors automatically, although manual intervention is sometimes still needed.
  • Mixed precision: tensors of more than one precision are used, namely torch.FloatTensor and torch.HalfTensor.

Why use AMP

The core idea: FP16 is the better choice in some cases, and FP32 in others.

There are three advantages of FP16:

  • Reduces memory usage
  • Accelerates training and inference (halving the data size reduces memory traffic and communication volume, which speeds up data movement)
  • Tensor Cores are increasingly widespread, and low-precision computation is an important trend.

Two major issues with FP16:

  • Overflow and underflow errors: FP16 has a very narrow dynamic range (the largest finite value is 65504), so it is prone to overflow and underflow. Overflow easily leads to "NaN" problems. In deep learning, activation gradients are often smaller than weight gradients, which makes underflow more likely. The smallest positive value FP16 can represent is $2^{-24}$; gradients below that underflow to 0, which can prevent weight updates.
  • Rounding errors: when a gradient is too small relative to the spacing between adjacent FP16 values at the weight's magnitude, the update may be lost entirely.
    For example: in FP16, if the weight is $2^{-3}$ and the gradient is $2^{-14}$, the updated weight is $2^{-3} + 2^{-14} = 2^{-3}$. Between $2^{-3}$ and $2^{-2}$, adjacent FP16 values are $2^{-13}$ apart, so a gradient of $2^{-14}$ (exactly half that spacing) is rounded away (round-half-to-even) and the weight is not updated.
If you don't understand this part, please refer to the last section: Data Representation.
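The FP16 limits described above can be reproduced without a GPU. This sketch assumes that rounding a Python float through `struct`'s `e` format behaves like an FP16 cast (round-to-nearest, ties-to-even). Because `struct` raises `OverflowError` instead of returning infinity, the hypothetical helper `to_fp16` maps that case to ±inf.

```python
import math
import struct

def to_fp16(x):
    """Round a Python float to the nearest FP16 value (ties to even)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses values beyond the FP16 range; IEEE rounding gives inf
        return math.copysign(math.inf, x)

print(to_fp16(65504.0))                 # 65504.0: largest finite FP16 value
print(to_fp16(70000.0))                 # inf: overflow
print(to_fp16(2.0 ** -24) == 2.0 ** -24)  # True: smallest positive FP16 value
print(to_fp16(2.0 ** -26))              # 0.0: underflow

# Rounding error: the gradient is lost when added to the weight
w, g = 2.0 ** -3, 2.0 ** -14
print(to_fp16(w + g) == w)              # True: 2**-14 is half the spacing 2**-13
```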

😤 To address these FP16 issues, there are two solutions: mixed precision training and loss scaling.

Mixed Precision Training

Use FP16 for storage and multiplication in memory to accelerate computation, while using FP32 for accumulation to avoid rounding errors.

The strategy of mixed precision training effectively alleviates the problem of rounding errors.
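A minimal sketch of why FP32 accumulation helps, reusing the numbers from the rounding example: 1000 updates of $2^{-14}$ are applied to a weight of $2^{-3}$, once to an FP16 weight and once to an FP32 copy of the weight (the "FP32 master weights" idea). FP16 and FP32 rounding are emulated with `struct` formats `e` and `f`; the helper `round_to` is hypothetical.

```python
import struct

def round_to(x, code):
    """Round a Python float to FP16 ('<e') or FP32 ('<f') via struct."""
    return struct.unpack(code, struct.pack(code, x))[0]

grad = 2.0 ** -14          # exactly representable in FP16
w16 = w32 = 2.0 ** -3      # same starting weight

for _ in range(1000):
    w16 = round_to(w16 + grad, "<e")  # FP16 weight: each update is rounded away
    w32 = round_to(w32 + grad, "<f")  # FP32 copy: the updates accumulate

print(w16)  # 0.125
print(w32)  # 0.18603515625
```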

Loss Scaling

Even with mixed precision training, the network may still fail to converge because activation gradients are so small that they underflow.

This can be prevented with torch.cuda.amp.GradScaler, which multiplies the loss by a scale factor so that the gradients computed from it are scaled up as well and no longer underflow. (🤓 Note: the scaled loss is only used to propagate gradients during backpropagation; before the weights are actually updated, the gradients are divided by the same factor again.)
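A tiny numeric sketch of the idea: a gradient of $2^{-26}$ underflows to 0 in FP16, but if the loss (and therefore, by the chain rule, every gradient) is multiplied by a scale factor first, the scaled gradient survives in FP16 and can be divided back out in higher precision. Here a single gradient value is scaled directly, and the scale $2^{16}$ is an arbitrary illustrative choice; the helper `to_fp16` is hypothetical.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest FP16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

true_grad = 2.0 ** -26     # below FP16's smallest positive value, 2**-24
scale = 2.0 ** 16          # loss scale (illustrative value)

unscaled_fp16 = to_fp16(true_grad)         # backward pass without scaling
scaled_fp16 = to_fp16(true_grad * scale)   # backward pass on the scaled loss
recovered = scaled_fp16 / scale            # unscale in higher precision (Python floats are FP64)

print(unscaled_fp16)            # 0.0: the gradient underflowed
print(recovered == true_grad)   # True
```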

How to use AMP

from torch import optim
from torch.cuda.amp import autocast, GradScaler

model = Net().cuda()  # Net, loss_fn, data and epochs are placeholders
optimizer = optim.SGD(model.parameters(), ...)

scaler = GradScaler()  # Instantiate a GradScaler object once, before training

for epoch in range(epochs):
    for input, target in data:
        input, target = input.cuda(), target.cuda()
        optimizer.zero_grad()

        with autocast():  # Run the forward pass and the loss under autocast
            output = model(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()  # Backpropagate the scaled loss
        # scaler.step() first unscales the gradients; if none of them is inf or NaN,
        # it calls optimizer.step() to update the weights, otherwise the step is skipped
        scaler.step(optimizer)
        scaler.update()  # Adjust the scale factor for the next iteration

The scale factor is estimated dynamically in every iteration. To minimize gradient underflow, it should be as large as possible, so the scaler gradually increases it.

However, if it becomes too large, half-precision values can easily overflow (turning into inf or NaN).

The principle of dynamic estimation is therefore to keep the scale factor as large as possible without producing inf or NaN gradients.

In each scaler.step(optimizer), the scaler checks whether any gradient is inf or NaN:

  • If inf or NaN occurs, scaler.step(optimizer) skips the weight update (optimizer.step()), and the following scaler.update() reduces the scale factor (multiplies it by backoff_factor).
  • If no inf or NaN occurs, the weights are updated normally, and once growth_interval consecutive iterations have passed without inf or NaN, scaler.update() increases the scale factor (multiplies it by growth_factor).
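The update rule above can be simulated in plain Python. This is a simplified sketch, not PyTorch's implementation: the class `ToyLossScaler` is hypothetical and merges the skip decision of `scaler.step` with the scale adjustment of `scaler.update`, but its defaults mirror `GradScaler`'s documented defaults (`init_scale=2**16`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`).

```python
class ToyLossScaler:
    """Simplified dynamic loss scaling (hypothetical, for illustration)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.good_steps = 0  # consecutive steps without inf/NaN gradients

    def update(self, found_inf):
        """Return True if the optimizer step should run, then adjust the scale."""
        if found_inf:
            # Skip the weight update and shrink the scale
            self.scale *= self.backoff_factor
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps == self.growth_interval:
            # Enough clean steps in a row: try a larger scale
            self.scale *= self.growth_factor
            self.good_steps = 0
        return True


scaler = ToyLossScaler(init_scale=1024.0, growth_interval=3)
history = []
for found_inf in [False, False, False, True, False]:
    stepped = scaler.update(found_inf)
    history.append((stepped, scaler.scale))
print(history)
# [(True, 1024.0), (True, 1024.0), (True, 2048.0), (False, 1024.0), (True, 1024.0)]
```

A small `growth_interval` is used here only so that growth becomes visible after a few steps.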

autocast is thread-local: its state applies only to the thread in which it was enabled. This matters for distributed training with torch.nn.DataParallel, and with torch.nn.DistributedDataParallel when a single process drives more than one GPU, because the model's forward then runs in separate threads.

In that case, autocast cannot be used as follows:

model = MyModel()
dp_model = nn.DataParallel(model)

# autocast is enabled only in the main thread
with autocast():
    output = dp_model(input)  # DataParallel's worker threads do not autocast
loss = loss_fn(output)

Instead, it should be used as follows:

class MyModel(nn.Module):
    @autocast()
    def forward(self, input):
        ...

# alternatively
class MyModel(nn.Module):
    def forward(self, input):
        with autocast():
            ...

model = MyModel()
dp_model = nn.DataParallel(model)

with autocast():
    output = dp_model(input)
    loss = loss_fn(output)

  • Enable autocast inside every forward, so that each thread runs under autocast.
  • Compute the loss under autocast as well.

Important Notes

  • Check whether the GPU supports fast FP16 computation (for example, Tensor Cores are available on Volta and newer architectures).
  • Constant range: to keep calculations from overflowing, first make sure that manually set constants such as epsilon and INF are representable in FP16 (the largest finite FP16 value is 65504, and the smallest positive one is $2^{-24} \approx 6 \times 10^{-8}$).
  • Matrix dimensions are best kept as multiples of 8 for optimal performance (🤯).
  • Operations involving sums are prone to overflow; for softmax, it is recommended to use the official API and define it as a layer in the model's __init__.
  • When using NVIDIA Apex's amp, some less commonly used functions must be registered before use: 🌰 amp.register_float_function(torch, 'sigmoid').
  • Layers should be defined in the model's __init__ function, while the computation graph should be built in the forward function.
  • Some functions do not support FP16 acceleration; it is advisable not to use them.
  • Every parameter that receives gradients must be managed by the optimizer passed to scaler.step(); otherwise, AMP cannot check whether its gradient is inf or NaN.
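To see the "constant range" note in action, the sketch below rounds a few typical hand-written constants to FP16 with `struct` (the helper `to_fp16` is a hypothetical pure-Python emulation, not a PyTorch API): an epsilon of `1e-8` silently becomes 0, while a large "INF" sentinel such as `1e9` overflows.

```python
import math
import struct

def to_fp16(x):
    """Round a Python float to FP16; values beyond the FP16 range become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

for name, value in [("eps", 1e-8), ("eps", 1e-4), ("INF", 1e9), ("INF", -1e9), ("INF", 6e4)]:
    print(name, value, "->", to_fp16(value))
# 1e-8 -> 0.0 (underflow), 1e-4 survives, +/-1e9 -> +/-inf, 6e4 -> 60000.0
```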

Data Representation

Von Neumann architecture: binary representation + five components (memory, control unit, arithmetic unit, input, output).

Harvard architecture: the biggest difference is that instructions and data are stored and accessed separately, so both can be fetched at the same time; many ARM processor cores use a (modified) Harvard architecture.

Overflow Issues

(lldb) print (233333 + 1) * (233333 + 1)
(int) $0 = -1389819292

$(x+1)^2 \ge 0$ is not always true, because integers can overflow: int is only 32 bits.
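Python integers never overflow, but the lldb result can be reproduced by keeping only the low 32 bits of the exact product and reading them as two's complement. `wrap_int32` is a hypothetical helper for this illustration.

```python
def wrap_int32(n):
    """Interpret the low 32 bits of n as a signed two's-complement int."""
    n &= 0xFFFFFFFF                      # keep only the low 32 bits
    return n - (1 << 32) if n & 0x80000000 else n

x = 233333
print((x + 1) * (x + 1))                 # 54444755556: the exact square
print(wrap_int32((x + 1) * (x + 1)))     # -1389819292: what a 32-bit int produces
```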

Floating-point numbers are represented differently from integers: overflow does not wrap around to a negative value (it produces infinity instead), but floating point has its own issues.

(lldb) print (1e20 + -1e20) + 3.14
(double) $0 = 3.1400000000000001
(lldb) print 1e20 + (-1e20 + 3.14)
(double) $1 = 0

This happens because floating-point addition rounds each intermediate result, so it is not associative; this is explained in detail below.
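Python floats are IEEE double precision, so the lldb experiment can be repeated directly. `math.ulp` (Python 3.9+) shows why the second result is 0: adjacent doubles near $10^{20}$ are 16384 apart, so adding 3.14 to $-10^{20}$ is rounded away.

```python
import math

print((1e20 + -1e20) + 3.14)   # 3.14
print(1e20 + (-1e20 + 3.14))   # 0.0
print(math.ulp(1e20))          # 16384.0: spacing between doubles near 1e20
print(-1e20 + 3.14 == -1e20)   # True: 3.14 is far below half that spacing
```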

Bit Representation

Everything in a computer is stored as bits, and each bit is either 0 or 1. Computers accomplish different tasks by encoding and interpreting bits in different ways.

From the perspective of analog circuits, two-valued signals are easy to store and remain reliable even in the presence of noise or imprecise transmission.

Integer

Signed and unsigned

  • Unsigned number: $B2U(X) = \sum_{i=0}^{w-1} x_i \cdot 2^i$
  • Signed (two's complement) number: $B2T(X) = -x_{w-1} \cdot 2^{w-1} + \sum_{i=0}^{w-2} x_i \cdot 2^i$
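The two formulas translate directly into code. In this sketch the hypothetical helpers `b2u` and `b2t` take a bit string written most significant bit first, so the character at position i holds $x_{w-1-i}$.

```python
def b2u(bits):
    """B2U: sum of x_i * 2**i over all w bits."""
    w = len(bits)
    return sum(int(b) << (w - 1 - i) for i, b in enumerate(bits))

def b2t(bits):
    """B2T: the top bit x_{w-1} has weight -2**(w-1), the rest are as in B2U."""
    w = len(bits)
    msb = int(bits[0])
    rest = sum(int(b) << (w - 1 - i) for i, b in enumerate(bits) if i > 0)
    return -msb * (1 << (w - 1)) + rest

print(b2u("1011"), b2t("1011"))  # 11 -5
print(b2u("1111"), b2t("1111"))  # 15 -1
```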

The main difference between signed and unsigned numbers lies in the highest bit: in a signed number it is the sign bit, with weight $-2^{w-1}$ instead of $+2^{w-1}$.

🤓 When converting between signed and unsigned numbers:

  • The bit pattern does not change; what changes is how the computer interprets it.
  • In C, if an expression mixes signed and unsigned operands of the same size, the signed operand is implicitly converted to unsigned, including in comparisons.
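A small sketch of both points for 32-bit values, emulating C casts with masks (the helper names are hypothetical): the bit pattern is untouched and only its interpretation changes, which is why the C comparison `-1 < 0U` is false.

```python
W = 32

def to_unsigned(x):
    """Reinterpret a signed 32-bit value's bits as unsigned, like (unsigned) x in C."""
    return x & ((1 << W) - 1)

def to_signed(u):
    """Reinterpret an unsigned 32-bit value's bits as two's complement."""
    return u - (1 << W) if u >> (W - 1) else u

print(hex(to_unsigned(-1)))   # 0xffffffff: same bits, now read as 4294967295
print(to_signed(0xFFFFFFFF))  # -1
# C converts -1 to unsigned before comparing with 0U, so -1 < 0U is false:
print(to_unsigned(-1) < 0)    # False
```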

Type Extension and Truncation

  • Extension: For example, from short int to int
    • Unsigned number: add 0
    • Signed number: add sign bit
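  • Truncation: for example, from unsigned int to unsigned short; for small values, the expected result is obtained.
    • Unsigned number: a mod operation ($x \bmod 2^k$ for a k-bit result)
    • Signed number: similar to a mod operation, but the remaining bits are reinterpreted as two's complement, so the sign can change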
short int x = 15213;
int ix = (int) x;
short int y = -15213;
int iy = (int) y;
       Decimal   Hexadecimal   Binary
x   =   15213    3B 6D         00111011 01101101
ix  =   15213    00 00 3B 6D   00000000 00000000 00111011 01101101
y   =  -15213    C4 93         11000100 10010011
iy  =  -15213    FF FF C4 93   11111111 11111111 11000100 10010011
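Both rules can be checked against the table with bit masks. The helpers below are hypothetical; the last line uses the classic example of casting the int 53191 to a 16-bit short, where truncation flips the sign.

```python
def sign_extend(bits, from_w, to_w):
    """Copy the sign bit of a from_w-bit pattern into the new high bits."""
    if (bits >> (from_w - 1)) & 1:                    # negative: fill with 1s
        bits |= ((1 << to_w) - 1) ^ ((1 << from_w) - 1)
    return bits

def truncate(bits, to_w):
    """Drop the high bits: the unsigned result is bits mod 2**to_w."""
    return bits % (1 << to_w)

def as_signed(bits, w):
    """Read a w-bit pattern as two's complement."""
    return bits - (1 << w) if (bits >> (w - 1)) & 1 else bits

print(hex(sign_extend(0x3B6D, 16, 32)))    # 0x3b6d      (x = 15213 -> ix)
print(hex(sign_extend(0xC493, 16, 32)))    # 0xffffc493  (y = -15213 -> iy)
print(as_signed(truncate(53191, 16), 16))  # -12345: int 53191 cast to short
```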

Integer operations and overflow

  • Signed numbers: overflow shows up as a sign change: adding two positive numbers yields a negative result, or adding two negative numbers yields a non-negative one.
  • Unsigned numbers: overflow happens when the carry out of the highest bit is discarded, so the sum wraps around modulo $2^w$ and ends up smaller than the numbers you added.
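The two overflow rules as a sketch for 32-bit addition (the helpers `uadd` and `tadd` are hypothetical): unsigned overflow is detected by the wrapped sum being smaller than an operand, and signed overflow by two same-sign operands producing a result of the other sign.

```python
W = 32
MASK = (1 << W) - 1

def uadd(a, b):
    """Unsigned W-bit addition; returns (result, overflowed)."""
    s = (a + b) & MASK            # the carry out of the top bit is discarded
    return s, s < a               # wrapped iff the sum became smaller

def tadd(a, b):
    """Two's-complement W-bit addition; returns (result, overflowed)."""
    s = (a + b) & MASK
    if s >> (W - 1):              # reinterpret the sign bit
        s -= 1 << W
    # Overflow iff both operands share a sign and the result's sign differs
    overflow = (a >= 0) == (b >= 0) and (s >= 0) != (a >= 0)
    return s, overflow

print(uadd(0xFFFFFFFF, 1))        # (0, True)
print(tadd(2 ** 31 - 1, 1))       # (-2147483648, True)
print(tadd(-(2 ** 31), -1))       # (2147483647, True)
print(tadd(-5, 3))                # (-2, False)
```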

Floating Point

A binary fraction $b_i b_{i-1} \ldots b_1 b_0 . b_{-1} \ldots b_{-j}$ represents the value:

$\sum_{k=-j}^{i} b_k \cdot 2^k$

It follows that only numbers of the form $\frac{x}{2^k}$ (with integer x) can be represented exactly; values such as 1/3 or 1/10 cannot.
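Python's `fractions.Fraction` can show the exact value a double actually stores, which makes this visible: 0.375 = 3/8 is stored exactly, while 0.1 is stored as some $x / 2^k$ close to, but not equal to, 1/10.

```python
from fractions import Fraction

# Fraction(x) gives the exact rational value stored in a Python float (a double)
print(Fraction(0.375))                   # 3/8: 0.011 in binary, stored exactly
print(Fraction(0.1) == Fraction(1, 10))  # False: 1/10 has no finite binary expansion
d = Fraction(0.1).denominator
print(d & (d - 1) == 0)                  # True: the denominator is a power of 2
```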

IEEE floating-point standard:

$(-1)^s \cdot M \cdot 2^E$

Where s is the sign bit, determining the sign; M is usually a value in [1.0, 2.0), and E is the exponent.

Normalized values: when $exp \neq 00\ldots0$ and $exp \neq 11\ldots1$, the encoding represents a normalized value.

E is a biased value: $E = Exp - Bias$

  • $Exp$: the unsigned value of the exp field.
  • $Bias$: the offset $2^{k-1} - 1$, where k is the number of bits in the exp field, so that
    • Single precision: 127
    • Double precision: 1023
    • Half precision: 15

For M, there is an implicit leading 1: $M = 1.xxx\ldots x_2$, where $xxx\ldots x$ is the content of the frac field.

For example:

float F = 15213.0;

$15213_{10} = 11101101101101_2 = 1.1101101101101_2 \times 2^{13}$

The frac field holds the digits after the binary point, 1101101101101, padded with zeros to 23 bits.

$Exp = E + Bias = 13 + 127 = 140 = 10001100_2$

s  exp       frac
0  10001100  11011011011010000000000
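The encoding can be confirmed with `struct`, which packs a Python float as an IEEE single-precision value with format `f`. This is only a check of the worked example, using the standard library.

```python
import struct

bits = int.from_bytes(struct.pack(">f", 15213.0), "big")  # raw FP32 bit pattern
s = bits >> 31
exp = (bits >> 23) & 0xFF
frac = bits & ((1 << 23) - 1)

print(hex(bits))                                  # 0x466db400
print(s, format(exp, "08b"), format(frac, "023b"))
# 0 10001100 11011011011010000000000
print(exp - 127)                                  # 13 = E = Exp - Bias
```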

🧐 What about the exp patterns excluded from the normalized case above?

When $exp = 00\ldots0$, the value is non-normalized (denormalized). Like all floating-point encodings, these map the continuous real number line onto a finite set of fixed values; for non-normalized values, the spacing between those values is uniform.

Unlike normalized values, $M = 0.xxx\ldots x_2$ (there is no implicit leading 1).

  • When $exp = 00\ldots0$: $E = 1 - Bias$
    • $frac = 00\ldots0$ represents 0.
    • $frac \neq 00\ldots0$ represents a value close to 0.
  • When $exp = 11\ldots1$: E is not applicable
    • $frac = 00\ldots0$ represents $\infty$.
    • $frac \neq 00\ldots0$ is not a number (NaN), used to represent an indeterminate value.

Now I will use the number line to illustrate this issue 😤

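Using the following 🌰, an 8-bit floating-point format with 1 sign bit, 4 exp bits and 3 frac bits (so $Bias = 2^{3} - 1 = 7$), to illustrate this issue: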

    s exp  frac   E   Value
------------------------------------------------------------------
    0 0000 000   -6   0   # This part is non-normalized, the next part is normalized values
    0 0000 001   -6   1/8 * 1/64 = 1/512 # The closest value to zero that can be represented
    0 0000 010   -6   2/8 * 1/64 = 2/512 
    ...
    0 0000 110   -6   6/8 * 1/64 = 6/512
    0 0000 111   -6   7/8 * 1/64 = 7/512 # The largest non-normalized value that can be represented
------------------------------------------------------------------
    0 0001 000   -6   8/8 * 1/64 = 8/512 # The smallest normalized value that can be represented
    0 0001 001   -6   9/8 * 1/64 = 9/512
    ...
    0 0110 110   -1   14/8 * 1/2 = 14/16
    0 0110 111   -1   15/8 * 1/2 = 15/16 # The closest value to 1 that is less than 1
    0 0111 000    0   8/8 * 1 = 1
    0 0111 001    0   9/8 * 1 = 9/8      # The closest value to 1 that is greater than 1
    0 0111 010    0   10/8 * 1 = 10/8
    ...
    0 1110 110    7   14/8 * 128 = 224
    0 1110 111    7   15/8 * 128 = 240   # The largest normalized value that can be represented
------------------------------------------------------------------
    0 1111 000   n/a  infinity               # Special value
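The rules above fit in one small decoder. This sketch (the function name `decode` is hypothetical) takes an unsigned bit pattern plus the exp and frac widths and returns the exact value as a `Fraction`; with 4 exp bits and 3 frac bits it reproduces the table, and with 5 and 10 it gives FP16.

```python
import math
from fractions import Fraction

def decode(pattern, exp_bits, frac_bits):
    """Exact value of a sign|exp|frac bit pattern (IEEE-style encoding)."""
    bias = (1 << (exp_bits - 1)) - 1
    frac = pattern & ((1 << frac_bits) - 1)
    exp = (pattern >> frac_bits) & ((1 << exp_bits) - 1)
    sign = -1 if pattern >> (exp_bits + frac_bits) else 1
    if exp == (1 << exp_bits) - 1:        # exp = 11...1: special values
        return sign * math.inf if frac == 0 else math.nan
    if exp == 0:                          # non-normalized: M = 0.frac, E = 1 - Bias
        m, e = Fraction(frac, 1 << frac_bits), 1 - bias
    else:                                 # normalized: M = 1.frac, E = Exp - Bias
        m, e = 1 + Fraction(frac, 1 << frac_bits), exp - bias
    return sign * m * Fraction(2) ** e

print(decode(0b0_0000_001, 4, 3))   # 1/512: closest value to zero
print(decode(0b0_1110_111, 4, 3))   # 240: largest normalized value
print(decode(0b0_1111_000, 4, 3))   # inf
print(decode(0x7BFF, 5, 10))        # 65504: largest finite FP16 value
print(decode(0x0001, 5, 10) == Fraction(1, 2 ** 24))  # True: smallest positive FP16
```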

Floating Point Rounding

IEEE floating-point addition and multiplication behave as if the exact value were computed first and then rounded to the target precision. The default mode is round-to-nearest, ties-to-even. The examples below round to 2 bits after the binary point (to the nearest 1/4):

  Value     Binary     Rounded (binary)  Rounded (decimal)  Reason
  2 3/32    10.00011   10.00             2                  Below halfway: round down
  2 3/16    10.00110   10.01             2 1/4              Above halfway: round up
  2 7/8     10.11100   11.00             3                  Exactly halfway: keep the last digit even, so round up
  2 5/8     10.10100   10.10             2 1/2              Exactly halfway: keep the last digit even, so round down
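The table can be reproduced with exact rational arithmetic. In this sketch, the hypothetical `round_to_quarter` relies on the fact that calling `round()` on a `Fraction` rounds half to even.

```python
from fractions import Fraction

def round_to_quarter(x):
    """Round x to 2 bits after the binary point, ties to even."""
    return Fraction(round(x * 4), 4)

for value in [Fraction(67, 32), Fraction(35, 16), Fraction(23, 8), Fraction(21, 8)]:
    print(value, "->", round_to_quarter(value))
# 67/32 -> 2, 35/16 -> 9/4, 23/8 -> 3, 21/8 -> 5/2
```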

Floating Point Addition

$(-1)^{s_1} M_1 2^{E_1} + (-1)^{s_2} M_2 2^{E_2}$

Assuming $E_1 > E_2$, the exact result is $(-1)^{s} M 2^{E}$, where $E = E_1$, and the sign s and significand M come from aligning and then adding: $(-1)^{s} M = (-1)^{s_1} M_1 + (-1)^{s_2} M_2 \cdot 2^{-(E_1 - E_2)}$ (that is, $M_2$ is shifted right by $E_1 - E_2$ bits before the addition).

  • If $M \ge 2$, shift M right by one bit and increment E.
  • If $M < 1$, shift M left by k bits and decrease E by k.
  • If E exceeds the representable range, overflow occurs.
  • Round M to the precision of frac.
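The steps above can be sketched with exact `Fraction` arithmetic. To keep it short, this hypothetical `fp_add` handles only positive operands with $1 \le M < 2$ (so the $M < 1$ left-shift case never arises) and skips the exponent range check; M is rounded half to even to `frac_bits` fraction bits.

```python
from fractions import Fraction

def fp_add(m1, e1, m2, e2, frac_bits):
    """Add m1 * 2**e1 + m2 * 2**e2 (positive, 1 <= m < 2); return (M, E)."""
    m1, m2 = Fraction(m1), Fraction(m2)
    if e1 < e2:                         # make sure E1 >= E2
        m1, e1, m2, e2 = m2, e2, m1, e1
    # Align: shift M2 right by E1 - E2 bits and add; the result has E = E1
    m = m1 + m2 / 2 ** (e1 - e2)
    e = e1
    # Normalize: if M >= 2, shift M right by one bit and increment E
    if m >= 2:
        m /= 2
        e += 1
    # Round M to frac_bits fraction bits; round() on a Fraction ties to even
    scale = 2 ** frac_bits
    m = Fraction(round(m * scale), scale)
    if m == 2:                          # rounding carried into a new bit
        m, e = Fraction(1), e + 1
    return m, e

# FP16 has 10 fraction bits: the weight-update example from the AMP section
print(fp_add(1, -3, 1, -14, 10))                          # (Fraction(1, 1), -3)
print(fp_add(Fraction(3, 2), 0, Fraction(3, 2), 0, 10))   # (Fraction(3, 2), 1), i.e. 3
```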

Basic properties:

  • Addition may produce infinity or NaN.
  • Satisfies the commutative law.
  • Does not satisfy the associative law.
  • Adding 0 equals the original number.
  • Except for infinity or NaN, each element has an additive inverse (its negation).
  • Except for infinity or NaN, monotonicity is satisfied.

Floating Point Multiplication

$(-1)^{s_1} M_1 2^{E_1} \times (-1)^{s_2} M_2 2^{E_2}$

The exact result is $(-1)^{s} M 2^{E}$, where $s = s_1 \oplus s_2$ (XOR), $M = M_1 \times M_2$, $E = E_1 + E_2$.

  • If $M \ge 2$, shift M right by one bit and increment E.
  • If E exceeds the representable range, overflow occurs.
  • Round M to the precision of frac.

Basic properties:

  • Multiplication may produce infinity or NaN.
  • Satisfies the commutative law.
  • Does not satisfy the associative law.
  • Multiplying by 1 equals the original number.
  • Except for infinity or NaN, monotonicity is satisfied.
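A short FP16 sketch of the commutative and non-associative properties. It assumes that multiplying two FP16 values in double precision is exact (their product has at most 22 significant bits) and then rounds once to FP16 with `struct`; the helpers `to_fp16` and `mul16` are hypothetical.

```python
import math
import struct

def to_fp16(x):
    """Round a Python float to FP16; overflow becomes +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def mul16(a, b):
    """FP16 multiply: exact product in double precision, then one rounding."""
    return to_fp16(a * b)

a, b, c = 256.0, 256.0, 1.0 / 256
print(mul16(mul16(a, b), c))        # inf: 256 * 256 = 65536 overflows FP16
print(mul16(a, mul16(b, c)))        # 256.0
print(mul16(a, b) == mul16(b, a))   # True: multiplication is still commutative
```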