optimize and support int8
parent e97663b388, commit 63e77fae38
@@ -2,7 +2,7 @@ import math
import oneflow as torch
import oneflow.nn.functional as F
from oneflow.nn.parameter import Parameter

from ..quantization import QuantizedLinear

def fast_gelu(x):
    """Mindspore's fast gelu implementation."""
@@ -13,7 +13,6 @@ def fast_gelu(x):

class MLP(torch.nn.Module):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension. At the end, dropout is also
@@ -52,7 +51,6 @@ class MLP(torch.nn.Module):

class SelfAttention(torch.nn.Module):
    """self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """
@@ -100,7 +98,7 @@ class SelfAttention(torch.nn.Module):
        # Query, Key, and Value
        # =====================

        if hasattr(torch._C, 'grouped_matmul_bias'):
        if hasattr(torch._C, 'grouped_matmul_bias') and not isinstance(self.query, QuantizedLinear):
            query_layer, key_layer, value_layer = torch._C.grouped_matmul_bias([hidden_states, hidden_states, hidden_states],
                                                                               [self.query.weight, self.key.weight, self.value.weight],
                                                                               [self.query.bias, self.key.bias, self.value.bias])
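The guard above only takes the fused grouped matmul when the oneflow build exposes it and the q/k/v projections are still plain linear layers; once quantize_oneflow has swapped them for QuantizedLinear, each projection runs through its own quantized forward instead. A minimal sketch of that capability check pulled out as a standalone helper (the helper name is illustrative, not part of the diff):

# Illustrative helper, not part of the commit: prefer the fused grouped matmul
# when available and the projections are ordinary Linear layers; quantized
# projections fall back to three separate calls.
import oneflow as flow
from codegeex.quantization import QuantizedLinear

def project_qkv(hidden_states, query, key, value):
    if hasattr(flow._C, "grouped_matmul_bias") and not isinstance(query, QuantizedLinear):
        return flow._C.grouped_matmul_bias(
            [hidden_states, hidden_states, hidden_states],
            [query.weight, key.weight, value.weight],
            [query.bias, key.bias, value.bias],
        )
    return query(hidden_states), key(hidden_states), value(hidden_states)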
@@ -108,49 +106,41 @@ class SelfAttention(torch.nn.Module):
            query_layer = self.query(hidden_states)
            key_layer = self.key(hidden_states)
            value_layer = self.value(hidden_states)

        new_query_layer_shape = query_layer.size()[:-1] + \
            (self.num_attention_heads,
             self.hidden_size_per_attention_head)
        query_layer = query_layer.view(*new_query_layer_shape)

        new_query_layer_shape = key_layer.size()[:-1] + \
            (self.num_attention_heads,
             self.hidden_size_per_attention_head)
        key_layer = key_layer.view(*new_query_layer_shape)

        new_query_layer_shape = value_layer.size()[:-1] + \
            (self.num_attention_heads,
             self.hidden_size_per_attention_head)
        value_layer = value_layer.view(*new_query_layer_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        origin_query_layer = query_layer
        origin_key_layer = key_layer
        origin_value_layer = value_layer
        fallback = not hasattr(torch._C, 'fused_multi_head_attention_inference_v2')

        if hasattr(torch._C, 'fused_multi_head_attention_inference'):
            if layer_past is not None:
                context_layer = torch._C.fused_multi_head_attention_inference(
                    origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=False
                ).transpose(0, 1)
        if fallback:
            if hasattr(torch._C, 'fused_codegeex_qkv_reshape'):
                query_layer, key_layer, value_layer = torch._C.fused_codegeex_qkv_reshape(query_layer, key_layer, value_layer, self.num_attention_heads)
            else:
                context_layer = torch._C.fused_multi_head_attention_inference(
                    origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=True
                ).transpose(0, 1)
        else:
            new_query_layer_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_query_layer_shape)

            new_query_layer_shape = key_layer.size()[:-1] + \
                (self.num_attention_heads,
                 self.hidden_size_per_attention_head)
            key_layer = key_layer.view(*new_query_layer_shape)

            new_query_layer_shape = value_layer.size()[:-1] + \
                (self.num_attention_heads,
                 self.hidden_size_per_attention_head)
            value_layer = value_layer.view(*new_query_layer_shape)

            # ==================================
            # Adjust key and value for inference
            # ==================================

            if layer_past is not None:
                past_key, past_value = layer_past
                key_layer = torch.cat((past_key.type_as(key_layer),
                                       key_layer), dim=0)
                value_layer = torch.cat((past_value.type_as(value_layer),
                                         value_layer), dim=0)
            if get_key_value:
                present = (key_layer, value_layer)

            # ===================================
            # Raw attention scores. [b, np, sq, sk]
            # ===================================
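The restructured block above only performs the manual head-split reshapes (or fused_codegeex_qkv_reshape, when that kernel exists) on the fallback path; with the v2 fused attention the projections stay in their packed layout. The reshape itself is just a view from [s, b, h] to [s, b, np, hn], as this standalone NumPy sketch illustrates (shapes made up for the example):

# NumPy illustration of the head-split view performed by query_layer.view(...)
# or fused_codegeex_qkv_reshape above: [s, b, h] -> [s, b, np, hn].
import numpy as np

seq_len, batch, num_heads, head_size = 5, 2, 4, 8
qkv = np.random.randn(seq_len, batch, num_heads * head_size).astype(np.float32)
per_head = qkv.reshape(seq_len, batch, num_heads, head_size)  # same data, new view
print(per_head.shape)  # (5, 2, 4, 8)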
@@ -167,7 +157,7 @@ class SelfAttention(torch.nn.Module):

            # Raw attention scores. [b * np, sq, sk]
            matmul_result = torch.matmul(query_layer.transpose(0, 1),
                                         key_layer.permute(1, 2, 0)) / self.norm_factor
                                         key_layer.transpose(0, 1).transpose(1, 2)) / self.norm_factor

            # change view to [b, np, sq, sk]
            attention_scores = matmul_result.view(*output_size)
@@ -194,11 +184,22 @@ class SelfAttention(torch.nn.Module):
                attention_mask = torch.clone(attention_mask)
                attention_mask[:, :, context_length:, :] = True

            attention_scores = attention_scores - attention_mask * 10000.0
            if self.attention_softmax_in_fp32:
                attention_probs = self.softmax(attention_scores.float()).half()
            attention_mask = ~attention_mask
            attention_mask = attention_mask.contiguous()

            # attention scores and attention mask [b, np, sq, sk]
            # attention_scores = attention_mask_func(attention_scores, attention_mask)
            if hasattr(torch._C, 'fused_scale_mask_softmax'):
                if self.attention_softmax_in_fp32:
                    attention_probs = torch._C.fused_scale_mask_softmax(attention_scores.float(), attention_mask, fill_value=-10000.0, scale=1.0).half()
                else:
                    attention_probs = torch._C.fused_scale_mask_softmax(attention_scores, attention_mask, fill_value=-10000.0, scale=1.0)
            else:
                attention_probs = self.softmax(attention_scores)
                attention_scores = attention_scores - attention_mask * 10000.0
                if self.attention_softmax_in_fp32:
                    attention_probs = self.softmax(attention_scores.float()).half()
                else:
                    attention_probs = self.softmax(attention_scores)

            # =========================
            # Context layer. [sq, b, hp]
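The new branch inverts the boolean mask (so True now marks positions that may be attended to), makes it contiguous, and lets fused_scale_mask_softmax fill the disallowed positions instead of subtracting 10000.0 from the raw scores. Assuming the fused op replaces scores where the mask is 0 with fill_value before applying the softmax, the two paths produce essentially the same probabilities; a small NumPy check of that equivalence:

# NumPy sketch: masking by subtracting a large constant vs. filling masked
# positions with -10000.0 before the softmax gives (numerically) the same result.
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([[1.0, 2.0, 3.0, 4.0]])
disallowed = np.array([[False, False, True, True]])  # original mask convention

subtracted = softmax(scores - disallowed * 10000.0)
filled = softmax(np.where(~disallowed, scores, -10000.0))
print(np.allclose(subtracted, filled))  # True; masked slots get ~0 probability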
@@ -220,7 +221,7 @@ class SelfAttention(torch.nn.Module):
            attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                                   output_size[2], -1)

            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
            context_layer = torch.bmm(attention_probs, value_layer.unsqueeze(0).transpose(1, 2).squeeze(0))

            # change view [b, np, sq, hn]
            context_layer = context_layer.view(*output_size)
@@ -232,10 +233,40 @@ class SelfAttention(torch.nn.Module):
            new_context_layer_shape = context_layer.size()[:-2] + \
                (self.hidden_size,)
            context_layer = context_layer.view(*new_context_layer_shape)
        else:
            if layer_past is not None:
                past_key, past_value = layer_past
                key_layer, value_layer = torch._C.fused_attention_concat_past_key_value(
                    past_key=past_key,
                    past_key_layout="MB(HK)",
                    past_value=past_value,
                    past_value_layout="MB(HK)",
                    key=key_layer,
                    key_layout="MB(HK)",
                    value=value_layer,
                    value_layout="MB(HK)",
                    key_head_size=self.hidden_size_per_attention_head,
                )
            if get_key_value:
                present = (key_layer, value_layer)

            context_layer = torch._C.fused_multi_head_attention_inference_v2(
                query=query_layer,
                key=key_layer,
                value=value_layer,
                query_head_size=self.hidden_size_per_attention_head,
                causal=True,
                causal_diagonal_offset=key_layer.shape[0]-query_layer.shape[0],
                query_layout="MB(HK)",
                key_layout="MB(HK)",
                value_layout="MB(HK)",
                output_layout="MB(HK)",
            )

        # =================
        # Output. [sq, b, h]
        # =================

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)
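fused_multi_head_attention_inference_v2 consumes query/key/value in the packed "MB(HK)" layout and takes over both the past-KV concatenation (via fused_attention_concat_past_key_value) and the causal masking. During incremental decoding the query holds only the newest tokens, so causal_diagonal_offset=key_len - query_len shifts the causal diagonal; my reading of that parameter, sketched in plain NumPy (not taken from the diff):

# NumPy sketch of the mask implied by causal_diagonal_offset = k_len - q_len:
# query position i (within the current chunk) may attend to key positions
# j <= i + offset, i.e. a single new token sees the whole cache.
import numpy as np

def shifted_causal_mask(q_len, k_len):
    offset = k_len - q_len
    i = np.arange(q_len)[:, None]
    j = np.arange(k_len)[None, :]
    return j <= i + offset  # True where attention is allowed

print(shifted_causal_mask(1, 5).astype(int))  # [[1 1 1 1 1]]
print(shifted_causal_mask(3, 3).astype(int))  # standard lower-triangular mask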
@@ -247,7 +278,6 @@ class SelfAttention(torch.nn.Module):

class TopQuerySelfAttention(torch.nn.Module):
    """Top query self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """
@@ -291,7 +321,7 @@ class TopQuerySelfAttention(torch.nn.Module):
    ):

        # hidden_states: [sq, b, h]
        if hasattr(torch._C, 'grouped_matmul_bias'):
        if hasattr(torch._C, 'grouped_matmul_bias') and not isinstance(self.query, QuantizedLinear):
            query_layer, key_layer, value_layer = torch._C.grouped_matmul_bias([query_hidden_state, hidden_states, hidden_states],
                                                                               [self.query.weight, self.key.weight, self.value.weight],
                                                                               [self.query.bias, self.key.bias, self.value.bias])
@@ -299,49 +329,41 @@ class TopQuerySelfAttention(torch.nn.Module):
            query_layer = self.query(query_hidden_state)
            key_layer = self.key(hidden_states)
            value_layer = self.value(hidden_states)

        fallback = not hasattr(torch._C, 'fused_multi_head_attention_inference_v2')

        new_query_layer_shape = query_layer.size()[:-1] + \
            (self.num_attention_heads,
             self.hidden_size_per_attention_head)
        query_layer = query_layer.view(*new_query_layer_shape)

        new_query_layer_shape = key_layer.size()[:-1] + \
            (self.num_attention_heads,
             self.hidden_size_per_attention_head)
        key_layer = key_layer.view(*new_query_layer_shape)

        new_query_layer_shape = value_layer.size()[:-1] + \
            (self.num_attention_heads,
             self.hidden_size_per_attention_head)
        value_layer = value_layer.view(*new_query_layer_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        origin_query_layer = query_layer
        origin_key_layer = key_layer
        origin_value_layer = value_layer

        if hasattr(torch._C, 'fused_multi_head_attention_inference'):
            if layer_past is not None:
                context_layer = torch._C.fused_multi_head_attention_inference(
                    origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=False
                ).transpose(0, 1)
        if fallback:
            if hasattr(torch._C, 'fused_codegeex_qkv_reshape'):
                query_layer, key_layer, value_layer = torch._C.fused_codegeex_qkv_reshape(query_layer, key_layer, value_layer, self.num_attention_heads)
            else:
                context_layer = torch._C.fused_multi_head_attention_inference(
                    origin_query_layer.view(query_layer.size()[1], query_layer.size()[0], -1), origin_key_layer.view(key_layer.size()[1], key_layer.size()[0], -1), origin_value_layer.view(value_layer.size()[1], value_layer.size()[0], -1), self.num_attention_heads, causal=True
                ).transpose(0, 1)
        else:
            new_query_layer_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_query_layer_shape)

            new_query_layer_shape = key_layer.size()[:-1] + \
                (self.num_attention_heads,
                 self.hidden_size_per_attention_head)
            key_layer = key_layer.view(*new_query_layer_shape)

            new_query_layer_shape = value_layer.size()[:-1] + \
                (self.num_attention_heads,
                 self.hidden_size_per_attention_head)
            value_layer = value_layer.view(*new_query_layer_shape)

            # ==================================
            # Adjust key and value for inference
            # ==================================

            if layer_past is not None:
                past_key, past_value = layer_past
                key_layer = torch.cat((past_key.type_as(key_layer),
                                       key_layer), dim=0)
                value_layer = torch.cat((past_value.type_as(value_layer),
                                         value_layer), dim=0)
            if get_key_value:
                present = (key_layer, value_layer)

            # ===================================
            # Raw attention scores. [b, np, sq, sk]
            # ===================================
@@ -386,18 +408,11 @@ class TopQuerySelfAttention(torch.nn.Module):

            # attention scores and attention mask [b, np, sq, sk]
            # attention_scores = attention_mask_func(attention_scores, attention_mask)
            if hasattr(torch._C, 'fused_scale_mask_softmax'):
                attention_mask = ~attention_mask
                if self.attention_softmax_in_fp32:
                    attention_probs = torch._C.fused_scale_mask_softmax(attention_scores.float(), attention_mask, fill_value=-10000.0, scale=1.0).half()
                else:
                    attention_probs = torch._C.fused_scale_mask_softmax(attention_scores, attention_mask, fill_value=-10000.0, scale=1.0)
                attention_scores = attention_scores - attention_mask * 10000.0
                if self.attention_softmax_in_fp32:
                    attention_probs = self.softmax(attention_scores.float()).half()
            else:
                attention_scores = attention_scores - attention_mask * 10000.0
                if self.attention_softmax_in_fp32:
                    attention_probs = self.softmax(attention_scores.float()).half()
                else:
                    attention_probs = self.softmax(attention_scores)
                attention_probs = self.softmax(attention_scores)

            # =========================
            # Context layer. [sq, b, hp]
@@ -433,9 +448,40 @@ class TopQuerySelfAttention(torch.nn.Module):
                (self.hidden_size,)
            context_layer = context_layer.view(*new_context_layer_shape)

            # =================
            # Output. [sq, b, h]
            # =================
        else:
            if layer_past is not None:
                past_key, past_value = layer_past
                key_layer, value_layer = torch._C.fused_attention_concat_past_key_value(
                    past_key=past_key,
                    past_key_layout="MB(HK)",
                    past_value=past_value,
                    past_value_layout="MB(HK)",
                    key=key_layer,
                    key_layout="MB(HK)",
                    value=value_layer,
                    value_layout="MB(HK)",
                    key_head_size=self.hidden_size_per_attention_head,
                )
            if get_key_value:
                present = (key_layer, value_layer)

            if hasattr(torch._C, 'fused_multi_head_attention_inference_v2'):
                context_layer = torch._C.fused_multi_head_attention_inference_v2(
                    query=query_layer,
                    key=key_layer,
                    value=value_layer,
                    query_head_size=self.hidden_size_per_attention_head,
                    causal=True,
                    causal_diagonal_offset=key_layer.shape[0]-query_layer.shape[0],
                    query_layout="MB(HK)",
                    key_layout="MB(HK)",
                    value_layout="MB(HK)",
                    output_layout="MB(HK)",
                )

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)
@@ -447,7 +493,6 @@ class TopQuerySelfAttention(torch.nn.Module):

class TransformerLayer(torch.nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """
@@ -527,7 +572,6 @@ class TransformerLayer(torch.nn.Module):

class TopQueryLayer(torch.nn.Module):
    """A single top query layer.

    Top query layer takes input with size [b, s, h] and returns an
    output of the same size.
    """
@@ -728,7 +772,6 @@ class Transformer(torch.nn.Module):

class Embedding(torch.nn.Module):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
@@ -808,7 +851,6 @@ class Embedding(torch.nn.Module):

class QueryEmbedding(torch.nn.Module):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
@@ -868,7 +910,6 @@ class QueryEmbedding(torch.nn.Module):

class TransformerLanguageModel(torch.nn.Module):
    """Transformer language model.

    Arguments:
        transformer_hparams: transformer hyperparameters
        attention_mask_func: a function that takes `unmasked-attention-scores`
@@ -1 +1,3 @@
from .quantize import quantize
from .quantize import quantize
from .quantize_oneflow import quantize_oneflow
from .quantize_oneflow import QuantizedLinear
@@ -0,0 +1,168 @@
import numpy as np
import oneflow as torch
from oneflow.nn.parameter import Parameter

def _pack_int8_to_int4(x):
    np_x = x.numpy()
    l = np_x[..., 0::2]
    r = np_x[..., 1::2]
    l = np.left_shift(l, 4)
    if x.dtype is np.int8:
        even = np.bitwise_and(r, np.int8(0xF))
    packed = torch.tensor(np.bitwise_or(l, r), device=x.device)
    return packed


def _quantize(num_bits, symmetric, x, group_dim, group_size, quant_type):
    x_float = x.float()
    x_reshaped = x_float.reshape(
        x.shape[:group_dim]
        + (x.shape[group_dim] // group_size, group_size)
        + x.shape[group_dim + 1 :]
    )
    if symmetric:
        signed_max = float(2 ** (num_bits - 1)) - 1
        offset = signed_max if quant_type is torch.uint8 else 0.0
        scale_float = (
            x_reshaped.abs().max(dim=group_dim + 1, keepdim=True).values / signed_max
        )
        quantized = (
            torch.round(x_reshaped / scale_float + offset)
            .reshape(x.shape)
            .to(quant_type)
        )
        if num_bits == 4:
            quantized = _pack_int8_to_int4(quantized)
        return (quantized, scale_float.squeeze(group_dim + 1).to(x.dtype), None)
    else:
        unsigned_max = float(2 ** num_bits) - 1
        mn = x_reshaped.min(dim=group_dim + 1, keepdim=True).values
        mx = x_reshaped.max(dim=group_dim + 1, keepdim=True).values
        scale_float = (mx - mn) / unsigned_max
        quantized = (
            torch.round((x_reshaped - mn) / scale_float).reshape(x.shape).to(torch.uint8)
        )
        if num_bits == 4:
            quantized = _pack_int8_to_int4(quantized)
        return (
            quantized,
            scale_float.squeeze(group_dim + 1).to(x.dtype),
            mn.squeeze(group_dim + 1).to(x.dtype),
        )

class QuantizedLinear(torch.nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        weight_bit_width: int,
        weight: torch.Tensor = None,
        bias: torch.Tensor = None,
        *args,
        **kwargs
    ):
        super(QuantizedLinear, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.weight_bit_width = weight_bit_width
        self.symmetric = True
        self.group_dim = 1
        self.group_size = in_features

        self.weight, self.weight_scale, self.weight_zero = _quantize(
            self.weight_bit_width, self.symmetric, weight, self.group_dim, self.group_size, torch.int8
        )
        if bias is None:
            self.register_parameter('bias', None)
        else:
            self.bias = bias
            self.bias = self.bias.to(kwargs["device"])

        self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
        self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)
        if self.bias is not None:
            self.bias = Parameter(self.bias.to(kwargs["device"]), requires_grad=False)
        if self.weight_zero is not None:
            self.weight_zero = Parameter(self.weight_zero.to(kwargs["device"]), requires_grad=False)

    def forward(self, input_):
        # Matrix multiply.
        output = torch._C.fused_linear_with_groupwise_quantized_weight(input_,
                                                                       w=self.weight,
                                                                       w_scale=self.weight_scale,
                                                                       w_zero=self.weight_zero,
                                                                       b=self.bias if self.bias is not None else None,
                                                                       num_bits=self.weight_bit_width,
                                                                       symmetric=self.symmetric,
                                                                       group_dim=self.group_dim,
                                                                       group_size=self.group_size)

        return output

def quantize_oneflow(model, weight_bit_width):
    """Replace fp16 linear with quantized linear"""

    for i in range(len(model.language_model.transformer.layers) + 1):
        if i == len(model.language_model.transformer.layers):
            layer = model.language_model.transformer.topQueryLayer
        else:
            layer = model.language_model.transformer.layers[i]

        layer.attention.query = QuantizedLinear(
            in_features=layer.attention.query.in_features,
            out_features=layer.attention.query.out_features,
            weight_bit_width=weight_bit_width,
            weight=layer.attention.query.weight.to(torch.cuda.current_device()),
            bias=layer.attention.query.bias.to(torch.cuda.current_device()),
            params_dtype=torch.half,
            device=layer.attention.query.weight.device,
        )
        layer.attention.value = QuantizedLinear(
            in_features=layer.attention.value.in_features,
            out_features=layer.attention.value.out_features,
            weight_bit_width=weight_bit_width,
            weight=layer.attention.value.weight.to(torch.cuda.current_device()),
            bias=layer.attention.value.bias.to(torch.cuda.current_device()),
            params_dtype=torch.half,
            device=layer.attention.value.weight.device,
        )
        layer.attention.key = QuantizedLinear(
            in_features=layer.attention.key.in_features,
            out_features=layer.attention.key.out_features,
            weight_bit_width=weight_bit_width,
            weight=layer.attention.key.weight.to(torch.cuda.current_device()),
            bias=layer.attention.key.bias.to(torch.cuda.current_device()),
            params_dtype=torch.half,
            device=layer.attention.key.weight.device,
        )
        layer.attention.dense = QuantizedLinear(
            in_features=layer.attention.dense.in_features,
            out_features=layer.attention.dense.out_features,
            weight_bit_width=weight_bit_width,
            weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
            bias=layer.attention.dense.bias.to(torch.cuda.current_device()),
            params_dtype=torch.half,
            device=layer.attention.dense.weight.device,
        )
        layer.mlp.dense_h_to_4h = QuantizedLinear(
            in_features=layer.mlp.dense_h_to_4h.in_features,
            out_features=layer.mlp.dense_h_to_4h.out_features,
            weight_bit_width=weight_bit_width,
            weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
            bias=layer.mlp.dense_h_to_4h.bias.to(torch.cuda.current_device()),
            params_dtype=torch.half,
            device=layer.mlp.dense_h_to_4h.weight.device,
        )
        layer.mlp.dense_4h_to_h = QuantizedLinear(
            in_features=layer.mlp.dense_4h_to_h.in_features,
            out_features=layer.mlp.dense_4h_to_h.out_features,
            weight_bit_width=weight_bit_width,
            weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
            bias=layer.mlp.dense_4h_to_h.bias.to(torch.cuda.current_device()),
            params_dtype=torch.half,
            device=layer.mlp.dense_4h_to_h.weight.device,
        )

    return model
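For the configuration quantize_oneflow actually uses (weight_bit_width=8, symmetric=True, group_dim=1, group_size=in_features), _quantize boils down to one absmax scale per output row plus a round-to-nearest int8 code, which the fused kernel dequantizes as q * scale at matmul time. A standalone NumPy sketch of that arithmetic and its reconstruction error:

# NumPy-only sketch of the symmetric per-row int8 scheme used above
# (weight_bit_width=8, one scale per output row of the weight matrix).
import numpy as np

w = np.random.randn(4, 16).astype(np.float32)          # [out_features, in_features]
signed_max = 2 ** (8 - 1) - 1                           # 127
scale = np.abs(w).max(axis=1, keepdims=True) / signed_max
q = np.round(w / scale).astype(np.int8)                 # int8 codes in [-127, 127]
w_hat = q.astype(np.float32) * scale                    # dequantized weights
print(np.abs(w - w_hat).max() <= (scale / 2).max())     # error bounded by half a step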
@@ -0,0 +1,39 @@
# This script is used to test the inference of CodeGeeX.

GPU=$1
PROMPT_FILE=$2

SCRIPT_PATH=$(realpath "$0")
SCRIPT_DIR=$(dirname "$SCRIPT_PATH")
MAIN_DIR=$(dirname "$SCRIPT_DIR")
TOKENIZER_PATH="$MAIN_DIR/codegeex/tokenizer/"

# import model configuration
source "$MAIN_DIR/configs/codegeex_13b.sh"

# export CUDA settings
if [ -z "$GPU" ]; then
    GPU=1
fi

export CUDA_HOME=/usr/local/cuda-11.1/
export CUDA_VISIBLE_DEVICES=$GPU

if [ -z "$PROMPT_FILE" ]; then
    PROMPT_FILE=$MAIN_DIR/tests/test_prompt.txt
fi

# remove --greedy if using sampling
CMD="python $MAIN_DIR/tests/test_inference_oneflow.py \
        --prompt-file $PROMPT_FILE \
        --tokenizer-path $TOKENIZER_PATH \
        --micro-batch-size 1 \
        --out-seq-length 1024 \
        --temperature 0.2 \
        --top-p 0.95 \
        --top-k 0 \
        --quantize \
        $MODEL_ARGS"

echo "$CMD"
eval "$CMD"
@@ -10,8 +10,9 @@ import numpy as np
from codegeex.oneflow.inference import get_token_stream
from codegeex.oneflow import CodeGeeXModel
from codegeex.tokenizer import CodeGeeXTokenizer
from codegeex.quantization import quantize
from codegeex.quantization import quantize_oneflow
os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1"
os.environ["ONEFLOW_LINEAR_EMBEDDING_SKIP_INIT"] = "1"

def model_provider(args):
    """Build the model."""
@@ -135,7 +136,7 @@ def main():
    model.eval()
    model.half()
    if args.quantize:
        model = quantize(model, weight_bit_width=8, backend="torch")
        model = quantize_oneflow(model, weight_bit_width=8)
    model.cuda()
    torch.cuda.synchronize()
    with open(args.prompt_file, "r") as f: