optimize and support int8

BBuf 2023-03-14 13:28:05 +00:00
parent e97663b388
commit 63e77fae38
23 changed files with 369 additions and 118 deletions

View File

@@ -2,7 +2,7 @@ import math
import oneflow as torch
import oneflow.nn.functional as F
from oneflow.nn.parameter import Parameter
from ..quantization import QuantizedLinear
def fast_gelu(x):
"""Mindspore's fast gelu implementation."""
@@ -13,7 +13,6 @@ def fast_gelu(x):
class MLP(torch.nn.Module):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension. At the end, dropout is also
@@ -52,7 +51,6 @@ class MLP(torch.nn.Module):
class SelfAttention(torch.nn.Module):
"""self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
@@ -100,7 +98,7 @@ class SelfAttention(torch.nn.Module):
# Query, Key, and Value
# =====================
if hasattr(torch._C, 'grouped_matmul_bias'):
if hasattr(torch._C, 'grouped_matmul_bias') and not isinstance(self.query, QuantizedLinear):
query_layer, key_layer, value_layer = torch._C.grouped_matmul_bias([hidden_states, hidden_states, hidden_states],
[self.query.weight, self.key.weight, self.value.weight],
[self.query.bias, self.key.bias, self.value.bias])
@@ -108,49 +106,41 @@ class SelfAttention(torch.nn.Module):
query_layer = self.query(hidden_states)
key_layer = self.key(hidden_states)
value_layer = self.value(hidden_states)
new_query_layer_shape = query_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_query_layer_shape)
new_query_layer_shape = key_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
key_layer = key_layer.view(*new_query_layer_shape)
new_query_layer_shape = value_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
value_layer = value_layer.view(*new_query_layer_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer),
key_layer), dim=0)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer), dim=0)
if get_key_value:
present = (key_layer, value_layer)
origin_query_layer = query_layer
origin_key_layer = key_layer
origin_value_layer = value_layer
fallback = not hasattr(torch._C, 'fused_multi_head_attention_inference_v2')
if hasattr(torch._C, 'fused_multi_head_attention_inference'):
if layer_past is not None:
context_layer = torch._C.fused_multi_head_attention_inference(
origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=False
).transpose(0, 1)
if fallback:
if hasattr(torch._C, 'fused_codegeex_qkv_reshape'):
query_layer, key_layer, value_layer = torch._C.fused_codegeex_qkv_reshape(query_layer, key_layer, value_layer, self.num_attention_heads)
else:
context_layer = torch._C.fused_multi_head_attention_inference(
origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=True
).transpose(0, 1)
else:
new_query_layer_shape = query_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_query_layer_shape)
new_query_layer_shape = key_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
key_layer = key_layer.view(*new_query_layer_shape)
new_query_layer_shape = value_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
value_layer = value_layer.view(*new_query_layer_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer),
key_layer), dim=0)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer), dim=0)
if get_key_value:
present = (key_layer, value_layer)
# ===================================
# Raw attention scores. [b, np, sq, sk]
# ===================================
@@ -167,7 +157,7 @@ class SelfAttention(torch.nn.Module):
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.matmul(query_layer.transpose(0, 1),
key_layer.permute(1, 2, 0)) / self.norm_factor
key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
@@ -194,11 +184,22 @@ class SelfAttention(torch.nn.Module):
attention_mask = torch.clone(attention_mask)
attention_mask[:, :, context_length:, :] = True
attention_scores = attention_scores - attention_mask * 10000.0
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
attention_mask = ~attention_mask
attention_mask = attention_mask.contiguous()
# attention scores and attention mask [b, np, sq, sk]
# attention_scores = attention_mask_func(attention_scores, attention_mask)
if hasattr(torch._C, 'fused_scale_mask_softmax'):
if self.attention_softmax_in_fp32:
attention_probs = torch._C.fused_scale_mask_softmax(attention_scores.float(), attention_mask, fill_value=-10000.0, scale=1.0).half()
else:
attention_probs = torch._C.fused_scale_mask_softmax(attention_scores, attention_mask, fill_value=-10000.0, scale=1.0)
else:
attention_probs = self.softmax(attention_scores)
attention_scores = attention_scores - attention_mask * 10000.0
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
else:
attention_probs = self.softmax(attention_scores)
# =========================
# Context layer. [sq, b, hp]
@@ -220,7 +221,7 @@ class SelfAttention(torch.nn.Module):
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
@@ -232,10 +233,40 @@ class SelfAttention(torch.nn.Module):
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size,)
context_layer = context_layer.view(*new_context_layer_shape)
else:
if layer_past is not None:
past_key, past_value = layer_past
key_layer, value_layer = torch._C.fused_attention_concat_past_key_value(
past_key=past_key,
past_key_layout="MB(HK)",
past_value=past_value,
past_value_layout="MB(HK)",
key=key_layer,
key_layout="MB(HK)",
value=value_layer,
value_layout="MB(HK)",
key_head_size=self.hidden_size_per_attention_head,
)
if get_key_value:
present = (key_layer, value_layer)
context_layer = torch._C.fused_multi_head_attention_inference_v2(
query=query_layer,
key=key_layer,
value=value_layer,
query_head_size=self.hidden_size_per_attention_head,
causal=True,
causal_diagonal_offset=key_layer.shape[0]-query_layer.shape[0],
query_layout="MB(HK)",
key_layout="MB(HK)",
value_layout="MB(HK)",
output_layout="MB(HK)",
)
# =================
# Output. [sq, b, h]
# =================
# =================
# Output. [sq, b, h]
# =================
output = self.dense(context_layer)
@@ -247,7 +278,6 @@ class SelfAttention(torch.nn.Module):
class TopQuerySelfAttention(torch.nn.Module):
"""Top query self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
@@ -291,7 +321,7 @@ class TopQuerySelfAttention(torch.nn.Module):
):
# hidden_states: [sq, b, h]
if hasattr(torch._C, 'grouped_matmul_bias'):
if hasattr(torch._C, 'grouped_matmul_bias') and not isinstance(self.query, QuantizedLinear):
query_layer, key_layer, value_layer = torch._C.grouped_matmul_bias([query_hidden_state, hidden_states, hidden_states],
[self.query.weight, self.key.weight, self.value.weight],
[self.query.bias, self.key.bias, self.value.bias])
@@ -299,49 +329,41 @@ class TopQuerySelfAttention(torch.nn.Module):
query_layer = self.query(query_hidden_state)
key_layer = self.key(hidden_states)
value_layer = self.value(hidden_states)
fallback = not hasattr(torch._C, 'fused_multi_head_attention_inference_v2')
new_query_layer_shape = query_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_query_layer_shape)
new_query_layer_shape = key_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
key_layer = key_layer.view(*new_query_layer_shape)
new_query_layer_shape = value_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
value_layer = value_layer.view(*new_query_layer_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer),
key_layer), dim=0)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer), dim=0)
if get_key_value:
present = (key_layer, value_layer)
origin_query_layer = query_layer
origin_key_layer = key_layer
origin_value_layer = value_layer
if hasattr(torch._C, 'fused_multi_head_attention_inference'):
if layer_past is not None:
context_layer = torch._C.fused_multi_head_attention_inference(
origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=False
).transpose(0, 1)
if fallback:
if hasattr(torch._C, 'fused_codegeex_qkv_reshape'):
query_layer, key_layer, value_layer = torch._C.fused_codegeex_qkv_reshape(query_layer, key_layer, value_layer, self.num_attention_heads)
else:
context_layer = torch._C.fused_multi_head_attention_inference(
origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=True
).transpose(0, 1)
else:
new_query_layer_shape = query_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_query_layer_shape)
new_query_layer_shape = key_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
key_layer = key_layer.view(*new_query_layer_shape)
new_query_layer_shape = value_layer.size()[:-1] + \
(self.num_attention_heads,
self.hidden_size_per_attention_head)
value_layer = value_layer.view(*new_query_layer_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer),
key_layer), dim=0)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer), dim=0)
if get_key_value:
present = (key_layer, value_layer)
# ===================================
# Raw attention scores. [b, np, sq, sk]
# ===================================
@@ -386,18 +408,11 @@ class TopQuerySelfAttention(torch.nn.Module):
# attention scores and attention mask [b, np, sq, sk]
# attention_scores = attention_mask_func(attention_scores, attention_mask)
if hasattr(torch._C, 'fused_scale_mask_softmax'):
attention_mask = ~attention_mask
if self.attention_softmax_in_fp32:
attention_probs = torch._C.fused_scale_mask_softmax(attention_scores.float(), attention_mask, fill_value=-10000.0, scale=1.0).half()
else:
attention_probs = torch._C.fused_scale_mask_softmax(attention_scores, attention_mask, fill_value=-10000.0, scale=1.0)
attention_scores = attention_scores - attention_mask * 10000.0
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
else:
attention_scores = attention_scores - attention_mask * 10000.0
if self.attention_softmax_in_fp32:
attention_probs = self.softmax(attention_scores.float()).half()
else:
attention_probs = self.softmax(attention_scores)
attention_probs = self.softmax(attention_scores)
# =========================
# Context layer. [sq, b, hp]
@@ -433,9 +448,40 @@ class TopQuerySelfAttention(torch.nn.Module):
(self.hidden_size,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
else:
if layer_past is not None:
past_key, past_value = layer_past
key_layer, value_layer = torch._C.fused_attention_concat_past_key_value(
past_key=past_key,
past_key_layout="MB(HK)",
past_value=past_value,
past_value_layout="MB(HK)",
key=key_layer,
key_layout="MB(HK)",
value=value_layer,
value_layout="MB(HK)",
key_head_size=self.hidden_size_per_attention_head,
)
if get_key_value:
present = (key_layer, value_layer)
if hasattr(torch._C, 'fused_multi_head_attention_inference_v2'):
context_layer = torch._C.fused_multi_head_attention_inference_v2(
query=query_layer,
key=key_layer,
value=value_layer,
query_head_size=self.hidden_size_per_attention_head,
causal=True,
causal_diagonal_offset=key_layer.shape[0]-query_layer.shape[0],
query_layout="MB(HK)",
key_layout="MB(HK)",
value_layout="MB(HK)",
output_layout="MB(HK)",
)
# =================
# Output. [sq, b, h]
# =================
output = self.dense(context_layer)
@@ -447,7 +493,6 @@ class TopQuerySelfAttention(torch.nn.Module):
class TransformerLayer(torch.nn.Module):
"""A single transformer layer.
Transformore layer takes input with size [b, s, h] and returns an
output of the same size.
"""
@@ -527,7 +572,6 @@ class TransformerLayer(torch.nn.Module):
class TopQueryLayer(torch.nn.Module):
"""A single top query layer.
Top query layer takes input with size [b, s, h] and returns an
output of the same size.
"""
@@ -728,7 +772,6 @@ class Transformer(torch.nn.Module):
class Embedding(torch.nn.Module):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
@@ -808,7 +851,6 @@ class Embedding(torch.nn.Module):
class QueryEmbedding(torch.nn.Module):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
@@ -868,7 +910,6 @@ class QueryEmbedding(torch.nn.Module):
class TransformerLanguageModel(torch.nn.Module):
"""Transformer language model.
Arguments:
transformer_hparams: transformer hyperparameters
attention_mask_func: a function that takes `unmaksed-attention-scores`
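For reference, the softmax change applied in both attention classes above reads as follows: the boolean attention_mask (True where a position must be masked out) is inverted and made contiguous, and fused_scale_mask_softmax fills the dropped positions with -10000.0 before the softmax, which matches the additive -10000 masking kept as the fallback. A minimal sketch of the two paths, assuming only the boolean mask convention and the operator signature that appear in this diff:

import oneflow as flow

def masked_softmax(scores, mask_out, softmax_in_fp32=True):
    # mask_out is True where attention is NOT allowed, as in the diff above
    if hasattr(flow._C, "fused_scale_mask_softmax"):
        keep = (~mask_out).contiguous()        # the fused op keeps True positions
        x = scores.float() if softmax_in_fp32 else scores
        probs = flow._C.fused_scale_mask_softmax(x, keep, fill_value=-10000.0, scale=1.0)
        return probs.half() if softmax_in_fp32 else probs
    # fallback: additive masking followed by a plain softmax, as before this commit
    scores = scores - mask_out * 10000.0
    x = scores.float() if softmax_in_fp32 else scores
    probs = flow.softmax(x, dim=-1)
    return probs.half() if softmax_in_fp32 else probs

The -10000.0 fill plays the same role as the additive mask: masked logits end up far below any real score, so their softmax weight is effectively zero in both fp16 and fp32.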

View File

@@ -1 +1,3 @@
from .quantize import quantize
from .quantize import quantize
from .quantize_oneflow import quantize_oneflow
from .quantize_oneflow import QuantizedLinear
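For context, the QuantizedLinear exported here is constructed from an existing fp16 layer's weight and bias, exactly as quantize_oneflow (new file below) does for every attention and MLP projection. A small sketch of that per-layer call, using a stand-in oneflow Linear that is not part of the diff:

import oneflow as torch
from codegeex.quantization import QuantizedLinear

fp16_linear = torch.nn.Linear(1024, 1024).half()    # stand-in for e.g. layer.attention.query
int8_linear = QuantizedLinear(
    in_features=fp16_linear.in_features,
    out_features=fp16_linear.out_features,
    weight_bit_width=8,
    weight=fp16_linear.weight,
    bias=fp16_linear.bias,
    params_dtype=torch.half,                         # passed through **kwargs, as in quantize_oneflow
    device=fp16_linear.weight.device,
)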

View File

@@ -0,0 +1,168 @@
import numpy as np
import oneflow as torch
from oneflow.nn.parameter import Parameter
def _pack_int8_to_int4(x):
    # Pack two 4-bit values (each stored in its own int8) into one byte:
    # even-indexed elements become the high nibble, odd-indexed ones the low nibble.
    np_x = x.numpy()
    l = np_x[..., 0::2]
    r = np_x[..., 1::2]
    l = np.left_shift(l, 4)
    if np_x.dtype == np.int8:
        # mask off sign-extended high bits so they cannot clobber the high nibble
        r = np.bitwise_and(r, np.int8(0xF))
    packed = torch.tensor(np.bitwise_or(l, r), device=x.device)
    return packed
def _quantize(num_bits, symmetric, x, group_dim, group_size, quant_type):
x_float = x.float()
x_reshaped = x_float.reshape(
x.shape[:group_dim]
+ (x.shape[group_dim] // group_size, group_size)
+ x.shape[group_dim + 1 :]
)
if symmetric:
signed_max = float(2 ** (num_bits - 1)) - 1
offset = signed_max if quant_type is torch.uint8 else 0.0
scale_float = (
x_reshaped.abs().max(dim=group_dim + 1, keepdim=True).values / signed_max
)
quantized = (
torch.round(x_reshaped / scale_float + offset)
.reshape(x.shape)
.to(quant_type)
)
if num_bits == 4:
quantized = _pack_int8_to_int4(quantized)
return (quantized, scale_float.squeeze(group_dim + 1).to(x.dtype), None)
else:
unsigned_max = float(2 ** num_bits) - 1
mn = x_reshaped.min(dim=group_dim + 1, keepdim=True).values
mx = x_reshaped.max(dim=group_dim + 1, keepdim=True).values
scale_float = (mx - mn) / unsigned_max
quantized = (
torch.round((x_reshaped - mn) / scale_float).reshape(x.shape).to(torch.uint8)
)
if num_bits == 4:
quantized = _pack_int8_to_int4(quantized)
return (
quantized,
scale_float.squeeze(group_dim + 1).to(x.dtype),
mn.squeeze(group_dim + 1).to(x.dtype),
)
class QuantizedLinear(torch.nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
weight_bit_width: int,
weight: torch.Tensor = None,
bias: torch.Tensor = None,
*args,
**kwargs
):
super(QuantizedLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight_bit_width = weight_bit_width
self.symmetric = True
self.group_dim = 1
self.group_size = in_features
self.weight, self.weight_scale, self.weight_zero = _quantize(
self.weight_bit_width, self.symmetric, weight, self.group_dim, self.group_size, torch.int8
)
if bias is None:
self.register_parameter('bias', None)
else:
self.bias = bias
self.bias = self.bias.to(kwargs["device"])
self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
if self.bias is not None:
self.bias = Parameter(self.bias.to(kwargs["device"]), requires_grad=False)
if self.weight_zero is not None:
self.weight_zero = Parameter(self.weight_zero.to(kwargs["device"]), requires_grad=False)
def forward(self, input_):
# Matrix multiply.
output = torch._C.fused_linear_with_groupwise_quantized_weight(input_,
w=self.weight,
w_scale=self.weight_scale,
w_zero=self.weight_zero,
b=self.bias if self.bias is not None else None,
num_bits=self.weight_bit_width,
symmetric=self.symmetric,
group_dim=self.group_dim,
group_size=self.group_size)
return output
def quantize_oneflow(model, weight_bit_width):
"""Replace fp16 linear with quantized linear"""
for i in range(len(model.language_model.transformer.layers) + 1):
if i == len(model.language_model.transformer.layers):
layer = model.language_model.transformer.topQueryLayer
else:
layer = model.language_model.transformer.layers[i]
layer.attention.query = QuantizedLinear(
in_features=layer.attention.query.in_features,
out_features=layer.attention.query.out_features,
weight_bit_width=weight_bit_width,
weight=layer.attention.query.weight.to(torch.cuda.current_device()),
bias=layer.attention.query.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.query.weight.device,
)
layer.attention.value = QuantizedLinear(
in_features=layer.attention.value.in_features,
out_features=layer.attention.value.out_features,
weight_bit_width=weight_bit_width,
weight=layer.attention.value.weight.to(torch.cuda.current_device()),
bias=layer.attention.value.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.value.weight.device,
)
layer.attention.key = QuantizedLinear(
in_features=layer.attention.key.in_features,
out_features=layer.attention.key.out_features,
weight_bit_width=weight_bit_width,
weight=layer.attention.key.weight.to(torch.cuda.current_device()),
bias=layer.attention.key.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.key.weight.device,
)
layer.attention.dense = QuantizedLinear(
in_features=layer.attention.dense.in_features,
out_features=layer.attention.dense.out_features,
weight_bit_width=weight_bit_width,
weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
bias=layer.attention.dense.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.dense.weight.device,
)
layer.mlp.dense_h_to_4h = QuantizedLinear(
in_features=layer.mlp.dense_h_to_4h.in_features,
out_features=layer.mlp.dense_h_to_4h.out_features,
weight_bit_width=weight_bit_width,
weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
bias=layer.mlp.dense_h_to_4h.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.mlp.dense_h_to_4h.weight.device,
)
layer.mlp.dense_4h_to_h = QuantizedLinear(
in_features=layer.mlp.dense_4h_to_h.in_features,
out_features=layer.mlp.dense_4h_to_h.out_features,
weight_bit_width=weight_bit_width,
weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
bias=layer.mlp.dense_4h_to_h.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.mlp.dense_4h_to_h.weight.device,
)
return model
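A quick round-trip check makes the symmetric int8 layout above concrete: with group_dim=1 and group_size=in_features (the settings QuantizedLinear uses), _quantize keeps one scale per output row, scale = max(|w|) / 127, and stores round(w / scale) as int8, so dequantization is just q * scale. A sketch, assuming this file is importable as codegeex.quantization.quantize_oneflow (the package path suggested by the __init__.py diff above):

import oneflow as torch
from codegeex.quantization.quantize_oneflow import _quantize

w = torch.randn(8, 16)                   # [out_features, in_features], fp32
q, scale, zero = _quantize(
    8, True, w, group_dim=1, group_size=16, quant_type=torch.int8
)
assert zero is None                      # the symmetric path returns no zero point
w_hat = q.float() * scale                # dequantize; scale broadcasts as [8, 1]
print((w - w_hat).abs().max())           # rounding error, at most about scale / 2 per row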

View File

@@ -0,0 +1,39 @@
# This script is used to test the inference of CodeGeeX.
GPU=$1
PROMPT_FILE=$2
SCRIPT_PATH=$(realpath "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
MAIN_DIR=$(dirname "$SCRIPT_DIR")
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"
# import model configuration
source "$MAIN_DIR/configs/codegeex_13b.sh"
# export CUDA settings
if [ -z "$GPU" ]; then
GPU=1
fi
export CUDA_HOME=/usr/local/cuda-11.1/
export CUDA_VISIBLE_DEVICES=$GPU
if [ -z "$PROMPT_FILE" ]; then
PROMPT_FILE=$MAIN_DIR/tests/test_prompt.txt
fi
# remove --greedy if using sampling
CMD="python $MAIN_DIR/tests/test_inference_oneflow.py \
--prompt-file $PROMPT_FILE \
--tokenizer-path $TOKENIZER_PATH \
--micro-batch-size 1 \
--out-seq-length 1024 \
--temperature 0.2 \
--top-p 0.95 \
--top-k 0 \
--quantize \
$MODEL_ARGS"
echo "$CMD"
eval "$CMD"

View File

@@ -10,8 +10,9 @@ import numpy as np
from codegeex.oneflow.inference import get_token_stream
from codegeex.oneflow import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize
from codegeex.quantization import quantize_oneflow
os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1"
os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1"
def model_provider(args):
"""Build the model."""
@@ -135,7 +136,7 @@ def main():
model.eval()
model.half()
if args.quantize:
model = quantize(model, weight_bit_width=8, backend="torch")
model = quantize_oneflow(model, weight_bit_width=8)
model.cuda()
torch.cuda.synchronize()
with open(args.prompt_file, "r") as f:
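Taken together with the quantization changes above, the intended int8 inference flow in this test is roughly the sketch below; model construction and the token-stream generation loop are omitted because their arguments are not shown in this diff, and `model` stands for a CodeGeeXModel with its checkpoint already loaded:

import os
# the test sets these before the model is built (see the top of this file)
os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1"
os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1"

import oneflow as torch
from codegeex.quantization import quantize_oneflow

model.eval()
model.half()                                          # quantize from the fp16 weights
model = quantize_oneflow(model, weight_bit_width=8)   # int8 path; enabled by --quantize
model.cuda()
torch.cuda.synchronize()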