你的位置:首页 > 信息动态 > 新闻中心
信息动态
联系我们

终结篇:t5模型结构的阅读

2021-11-30 2:11:06

问题关键:past_key_value

    • 模型的整体结构(由外到内)
    • 最外层generation_utils.py之中的greedy_search调用模型解读
    • t5Stack模型的解读
    • t5block网络层中内容解读
    • t5layerselfattention的解读
    • t5layerselfattention+t5layercrossattention中t5layerselfattention代码解读
      • t5attention的第一次运行
      • t5attention encoder 第二次调用
    • t5layerselfattention+t5layercrossattention中t5layerselfattention代码解读
      • 第一次decoder部分的t5layerselfattention代码调用
      • 第二次decoder部分的t5layerselfattention代码调用(这里的第二次为调用了6个encoder的t5layerselfattention以及decoder中的6个encoder的t5layerselfattention和t5layercrossattention内容)
    • t5layerselfattention+t5layercrossattention中t5layercrossattention代码解读
      • 第一次调用t5layercrossattention
      • 第二次调用t5layercrossattention内容

通过之前对于代码的阅读,发现问题的关键就在于past_key_value参数的变化,导致输入不需要那么复杂的输入了,

模型的整体结构(由外到内)

模型的整体结构决定着数据的运转方向
模型整体的框架结构图
模型整体的框架结构图

最外层generation_utils.py之中的greedy_search调用模型解读

while True:
      if synced_gpus:
          # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
          # The following logic allows an early break if all peers finished generating their sequence
          this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
          # send 0.0 if we finished, 1.0 otherwise
          dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
          # did all peers finish? the reduced sum will be 0.0 then
          if this_peer_finished_flag.item() == 0.0:
              break
              
      model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

      # forward pass to get next token
      outputs = self(
          **model_inputs,
          return_dict=True,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
      )
      if synced_gpus and this_peer_finished:
          cur_len = cur_len + 1
          continue  # don't waste resources running the code we don't need

      next_token_logits = outputs.logits[:, -1, :]

      # Store scores, attentions and hidden_states when required
      if return_dict_in_generate:
          if output_scores:
              scores += (next_token_logits,)
          if output_attentions:
              decoder_attentions += (
                  (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
              )
              if self.config.is_encoder_decoder:
                  cross_attentions += (outputs.cross_attentions,)

          if output_hidden_states:
              decoder_hidden_states += (
                  (outputs.decoder_hidden_states,)
                  if self.config.is_encoder_decoder
                  else (outputs.hidden_states,)
              )

      # pre-process distribution
      next_tokens_scores = logits_processor(input_ids, next_token_logits)

      # argmax
      next_tokens = torch.argmax(next_tokens_scores, dim=-1)

      # finished sentences should have their next token be a padding token
      if eos_token_id is not None:
          if pad_token_id is None:
              raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
          next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

      # update generated ids, model inputs, and length for next step
      input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
      model_kwargs = self._update_model_kwargs_for_generation(
          outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
      )
      cur_len = cur_len + 1

      # if eos_token was found in one sentence, set sentence to finished
      if eos_token_id is not None:
          unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

      # stop when each sentence is finished, or if we exceed the maximum length
      if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
          if not synced_gpus:
              break
          else:
              this_peer_finished = True

这里的输入部分

input_ids = torch.cat([input_ids,next_tokens[:,None]],dim=-1)

得到input_ids = [0,644]
然后

model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

这里估摸着参数为之前传下来的参数(目前只看了[0][0] = (1,8,1,64))

model_inputs['past_key_value'][0][0] = torch.Size([1, 8, 1, 64])
model_inputs['past_key_value'][0][1] = torch.Size([1, 8, 1, 64])
model_inputs['past_key_value'][1][0] = torch.Size([1, 8, 11, 64])
model_inputs['past_key_value'][1][1] = torch.Size([1, 8, 11, 64])

t5Stack模型的解读

t5stack的定义

def forward(
    self,
    input_ids=None,
    attention_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    inputs_embeds=None,
    head_mask=None,
    cross_attn_head_mask=None,
    past_key_values=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):

进入t5stack类别之中查看内容

for i,(layer_module,past_key_value) in enumerate(zip(self.block,past_key_values)):
	............
	else:
        layer_outputs = layer_module(
          hidden_states,
          attention_mask=extended_attention_mask,
          position_bias=position_bias,
          encoder_hidden_states=encoder_hidden_states,
          encoder_attention_mask=encoder_extended_attention_mask,
          encoder_decoder_position_bias=encoder_decoder_position_bias,
          layer_head_mask=layer_head_mask,
          cross_attn_layer_head_mask=cross_attn_layer_head_mask,
          past_key_value=past_key_value,
          use_cache=use_cache,
          output_attentions=output_attentions,
       )

初始的时候这里的layer_module读取的是模型,past_key_values中存储的是6个None,到后面由于上一层的greedy_search中的参数不同,所以传入的past_key_values参数不同。
这里面的past_key_value保存的是6个对应的past_key_value内容(第一次全部为None),

past_key_value[0][0] = (1,8,1,64)
past_key_value[0][1] = (1,8,1,64)
past_key_value[0][2] = (1,8,11,64)
past_key_value[0][3] = (1,8,11,64)
............
............
past_key_value[5][0] = (1,8,1,64)
past_key_value[5][1] = (1,8,1,64)
past_key_value[5][2] = (1,8,11,64)
past_key_value[5][3] = (1,8,11,64)

上一次的t5stack留给这次的t5stack使用,都是同一层的在使用
注意,t5stack之中的past_value_state 第一次为 [None,None,None,None,None,None],后续每一次都是上一波的遗留下来的内容
也就是说,后面的t5block网络层中的内容传入的是同时期上一波t5block的输出,比如第二次调用t5block第二层的内容传入的是第一次调用t5block网络层第二层的内容。

t5block网络层中内容解读

进入t5block之中使用

hidden_states,present_key_value_state = self_attention_outputs[:2]

这里传递的是t5layerselfattention网络层之中传播的先前预测的内容(上一个网络结构同层的内容),这也能理解为什么这里刚开始

self_attn_past_key_value = past_key_value[:2]
......
......
self_attention_outputs = self.layer[0](
	......
	past_key_value=self_attn_past_key_value,
	......
)

获得的

self_attn_past_key_value[0][0] = (1,8,1,64)
self_attn_past_key_value[0][1] = (1,8,1,64)
self_attn_past_key_value[0][2] = (1,8,11,64)
self_attn_past_key_value[0][3] = (1,8,11,64)

经过这一波数据输出之后,调用新的present_key_value_state

hidden_states,present_key_value_state = self_attention_outputs[:2]

这里的present_key_value_state的内容为

present_key_value_state[0] = 
torch.Size([1, 8, 1, 64])
present_key_value_state[1] = 
torch.Size([1, 8, 1, 64])

接下来经过decoder部分之后,调用新的present_key_value_state

cross_attention_outputs = self.layer[1](
    hidden_states,
    key_value_states=encoder_hidden_states,
    attention_mask=encoder_attention_mask,
    position_bias=encoder_decoder_position_bias,
    layer_head_mask=cross_attn_layer_head_mask,
    past_key_value=cross_attn_past_key_value,
    query_length=query_length,
    use_cache=use_cache,
    output_attentions=output_attentions,
)

获得的新的present_key_value_state的内容

# Combine self attn and cross attn key value states
if present_key_value_state is not None:
    present_key_value_state = present_key_value_state + cross_attention_outputs[1]

获得新的present_key_value_state的内容为

present_key_value_state = 
torch.Size([1, 8, 1, 64])
torch.Size([1, 8, 1, 64])
torch.Size([1, 8, 11, 64])
torch.Size([1, 8, 11, 64])

另外两个位置的偏移参数也保存在后面

# Keep cross-attention outputs and relative position weights
attention_outputs = attention_outputs + cross_attention_outputs[2:]

获得的位置偏移的内容为

attention_outputs = 
torch.Size([1, 8, 1, 1])
torch.Size([1, 8, 1, 11])

t5layerselfattention的解读

t5block之中有两种模式,一种是t5layerselfattention的解读,一种是t5layerselfattention+t5layercrossattention网络结构的解读,这里我们解读t5layerselfattention
这里注入的past_key_value的内容应该为

None或者
(1,8,1,64)
(1,8,1,64)

t5layerselfattention+t5layercrossattention中t5layerselfattention代码解读

t5block之中有两种模式,一种是t5layerselfattention的解读,一种是t5layerselfattention+t5layercrossattention网络结构的解读,这里我们解读t5layerselfattention+t5layercrossattention网络结构中t5layerselfattention的代码内容
t5layerselfattention直接进入t5attention的内容之中

t5attention的第一次运行

第一次运行的时候

batch_size = 1,seq_length = 11,key_length = 11

然后我们进入调用的过程

query_states = shape(self.q(hidden_states))

得到

query_states = (1,8,1,64)

(这里的query_states的内容其实是固定的)
接下来进入对于key_states和value_states的操作之中

key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)

进入project函数之中去查看内容

def project(hidden_states, proj_layer, key_value_states, past_key_value):
    """projects hidden states correctly to key/query states"""
    if key_value_states is None:
        # self-attn
        # (batch_size, n_heads, seq_length, dim_per_head)
        hidden_states = shape(proj_layer(hidden_states))
    elif past_key_value is None:
        # cross-attn
        # (batch_size, n_heads, seq_length, dim_per_head)
        hidden_states = shape(proj_layer(key_value_states))

    if past_key_value is not None:
        if key_value_states is None:
            # self-attn
            # (batch_size, n_heads, key_length, dim_per_head)
            hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
        else:
            # cross-attn
            hidden_states = past_key_value
    return hidden_states

这里的key_value_states is None,后面的elif、if语句都没有被调用过,直接调用网络层

hidden_states = shape(proj_layer(hidden_states))

获得的结果

hidden_states = torch.size([1,8,11,64])

接下来调用

key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)

获得的结果

key_states = torch.Size([1, 8, 11, 64])
value_states = torch.Size([1, 8, 11, 64])

然后计算相应的分数

# compute scores
scores = torch.matmul(
    query_states, key_states.transpose(3, 2)
)  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

获得结果

scores = (1,8,11,11)

接下来计算position_bias的内容

if position_bias is None:
    if not self.has_relative_attention_bias:
        position_bias = torch.zeros(
            (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
        )
        if self.gradient_checkpointing and self.training:
            position_bias.requires_grad = True
    else:
        position_bias = self.compute_bias(real_seq_length, key_length)

    # if key and values are already calculated
    # we want only the last query position bias
    if past_key_value is not None:
        position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

    if mask is not None:
        position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

这里应该运行的内容为

position_bias = self.compute_bias(real_seq_length,key_length)

得到position_bias的形状

position_bias = (1,8,11,11)

接下来的操作的内容

scores += position_bias
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
    scores
)  # (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.dropout(
    attn_weights, p=self.dropout, training=self.training
)  # (batch_size, n_heads, seq_length, key_length)
# Mask heads if we want to
if layer_head_mask is not None:
    attn_weights = attn_weights * layer_head_mask

这里的attn_weights = (1,8,11,11)
然后接下来经历一波输出

attn_output = unshape(torch.matmul(attn_weights,value_states))
attn_output = self.o(attn_output)

attn_weights,包括key_states、value_states以及position_bias都相当于中间过程的参数内容,只有outputs是最终结果的参数的内容
最后将这些内容保存成tulpe输出

present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
if output_attentions:
    outputs = outputs + (attn_weights,)
return outputs

这里计算出来的position_bias第一次为None,之后计算出来会往后传递,节约了模型的运行时间。position_bias在6个encoder中的selflayerattention部分是一样的,在6个decoder中的selflayerattention是一样的,6个decoder中的selfcrossattention中的内容是一样的,selflayerattention和selfcrossattention中的position_bias是不一样的

t5attention encoder 第二次调用

第一次调用完就结束了,在预测过程之中,encoder只调用一次6个对应的t5attention encoder内容,encoder调用完成之后,decoder部分的内容是不断地被调用,直到decoder部分输出预测的停止符号为止

t5layerselfattention+t5layercrossattention中t5layerselfattention代码解读

第一次调用的过程没有之前的t5layerselfattention的调用,decoder_input_ids = (1,1)
这里的decoder_input_ids是一开始就初始化好的输入的参数,与之前的encoder_outputs的内容无关
从t5forconditionalgeneration的类别来看

decoder_outputs = self.decoder(
    input_ids=decoder_input_ids,
    attention_mask=decoder_attention_mask,
    inputs_embeds=decoder_inputs_embeds,
    past_key_values=past_key_values,
    encoder_hidden_states=hidden_states,
    encoder_attention_mask=attention_mask,
    head_mask=decoder_head_mask,
    cross_attn_head_mask=cross_attn_head_mask,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)

这里的之前encoder输出的部分只有

encoder_hidden_states=hidden_states

调用了之前的hidden_states = (1,11,512),其余的参数都与encoder部分无关
然后进入t5block的类别进行查看

self_attention_outputs = self.layer[0](
    hidden_states,
    attention_mask=attention_mask,
    position_bias=position_bias,
    layer_head_mask=layer_head_mask,
    past_key_value=self_attn_past_key_value,
    use_cache=use_cache,
    output_attentions=output_attentions,
)

也就是说在decoder中的selflayerattention调用始终与之前encoder的输出无关
查看t5block中的decoder部分内容

cross_attention_outputs = self.layer[1](
   hidden_states,
   key_value_states=encoder_hidden_states,
   attention_mask=encoder_attention_mask,
   position_bias=encoder_decoder_position_bias,
   layer_head_mask=cross_attn_layer_head_mask,
   past_key_value=cross_attn_past_key_value,
   query_length=query_length,
   use_cache=use_cache,
   output_attentions=output_attentions,
)

decoder中的crosslayerattention部分会调用之前encoder的输出

key_value_states = encoder_hidden_states

这里我们先查看第一次encoder部分的输出

第一次decoder部分的t5layerselfattention代码调用

刚开始的参数

batch_size,seq_length = hidden_states.shape[:2]
real_seq_length = seq_length

获得的参数

batch_size = 1,seq_length = 1,real_seq_length = 1

接下来这里调用网络层是不变的

query_states = shape(self.q(hidden_states))

获得query_states内容

query_states = torch.Size([1, 8, 1, 64])

然后调用

key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)

获得形状

key_states = torch.tensor([1, 8, 1, 64])
value_states = torch.tensor([1, 8, 1, 64])

后面的程序操作与上面操作类似,最后调用输出内容

outputs = (attn_output,)+(present_key_value_state,)+(position_bias,)

第二次decoder部分的t5layerselfattention代码调用(这里的第二次为调用了6个encoder的t5layerselfattention以及decoder中的6个encoder的t5layerselfattention和t5layercrossattention内容)

这里的第二次相当于预测完第一个数值之后,第二次运行到新的位置。这里调用的past_key_value[0]相当于上一个位置同一层输出的key_states,past_key_value[1]相当于上一个位置同一层输出的value_states(比如这里是第二波6个encoder+3个decoder+第4个decoder的selflayerattention,那么前面就相当于第一波的6个encoder+3个decoder+第4个decoder的selflayerattention的内容)
接下来进入

key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)
if past_key_value is not None:
    if key_value_states is None:
        # self-attn
        # (batch_size, n_heads, key_length, dim_per_head)
        hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
    else:
        # cross-attn
        hidden_states = past_key_value

这里如果是t5layerselfattention的时候会调用第一个if,如果是crossattention的时候会调用第二个if
如果为t5layerselfattention的时候,在project函数里面会调用如下代码

if past_key_value is not None:
    if key_value_states is None:
        # self-attn
        # (batch_size, n_heads, key_length, dim_per_head)
        hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
        ............
return hidden_states

获得第二波中的输出内容

key_states.size = torch.Size([1, 8, 2, 64])
value_states.size = torch.Size([1, 8, 2, 64])

接下来调用scores内容

# compute scores
scores = torch.matmul(
    query_states, key_states.transpose(3, 2)
)  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

获得的结果

scores = torch.Size([1, 8, 1, 2])

接下来查看position_bias的计算

if position_bias is None:
     if not self.has_relative_attention_bias:
         position_bias = torch.zeros(
             (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
         )
         if self.gradient_checkpointing and self.training:
             position_bias.requires_grad = True
     else:
         position_bias = self.compute_bias(real_seq_length, key_length)

这里得到的position_bias的结果

position_bias = torch.Size([1, 8, 2, 2])

接下来的操作,有对应的一行小字标注:

if key and values are already calculated,
we want only the last query position bias.

调用对应的代码

if past_key_value is not None:
   position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

注意取出来的是最后面的一维,取出来之后,position_bias = (1,8,1,2)
然后调用语句

scores += position_bias
#scores = (1,8,1,2)
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
    scores
)  # (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.dropout(
    attn_weights, p=self.dropout, training=self.training
)  # (batch_size, n_heads, seq_length, key_length)

# Mask heads if we want to
if layer_head_mask is not None:
    attn_weights = attn_weights * layer_head_mask

到这为止scores的内容都为(1,8,1,2)
接下来调用

attn_output = unshape((torch.matmul(attn_weights,value_states))

attn_weights = (1,8,1,2),value_states = (1,8,2,64)
相乘之后得到结果(1,8,1,64)
然后使用unshape之后进行输出

attn_output = unshape(torch.matmul(attn_weights,value_states))
#attn_output = (1,1,512)
attn_output = self.o(attn_output)

获得结果

attn_output = (1,1,512)

t5layerselfattention+t5layercrossattention中t5layercrossattention代码解读

t5block之中有两种模式,一种是t5layerselfattention的解读,一种是t5layerselfattention+t5layercrossattention网络结构的解读,这里我们解读t5layerselfattention+t5layercrossattention网络结构中t5layercrossattention的代码内容

第一次调用t5layercrossattention

前面的参数跟selflayerattention差不多

batch_size = 1,seq_length = 1,real_seq_length = 1

接着调用语句

key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

这里由于key_value_states不为None,所以这里获取得到的是

key_length = 11

这里的key_value_states = (1,11,512),就是之前6个encoder输出的内容的结果(6个t5layercrossattention的结果一样)
接下来调用project映射部分的内容

def project(hidden_states, proj_layer, key_value_states, past_key_value):
    """projects hidden states correctly to key/query states"""
    if key_value_states is None:
        # self-attn
        # (batch_size, n_heads, seq_length, dim_per_head)
        hidden_states = shape(proj_layer(hidden_states))
    elif past_key_value is None:
        # cross-attn
        # (batch_size, n_heads, seq_length, dim_per_head)
        hidden_states = shape(proj_layer(key_value_states))

    if past_key_value is not None:
        if key_value_states is None:
            # self-attn
            # (batch_size, n_heads, key_length, dim_per_head)
            hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
        else:
            # cross-attn
            hidden_states = past_key_value
    return hidden_states

第一波layercrossattention直接调用上面的这一语句

elif past_key_value is None:
	hidden_states = shape(proj_layer(key_value_states))

这里的key_value_states的内容为之前encoder输出的部分(1,8,11,64)
因此这里的hidden_states = (1,8,11,64)
然后past_key_value == None后面的if语句没有调用
接下来调用

query_states = shape(self.q(hidden_states))

query_states = (1,8,1,64)
然后后面两个调用

key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)

得到

key_states = (1,8,11,64)
value_states = (1,8,11,64)

然后scores调用中间过程

scores = torch.matmul(query_states,key_states.transpose(3,2))

得到结果

scores = (1,8,1,64)*(1,8,64,11) = (1,8,1,11)

接着调用后续的语句

scores += position_bias
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
    scores
)  # (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.dropout(
    attn_weights, p=self.dropout, training=self.training
)  # (batch_size, n_heads, seq_length, key_length)

# Mask heads if we want to
if layer_head_mask is not None:
    attn_weights = attn_weights * layer_head_mask    

attn_weights = (1,8,1,11)
最后相乘并返回

attn_output = unshape(torch.matmul(attn_weights,value_states))
attn_output= self.o(attn_output)

获得结果

attn_output = (1,8,1,11)*(1,8,11,64) = (1,8,1,64)->(1,1,512)
attn_output经过线性层之后->(1,1,512)

最后把这些参数都打包在一起进行输出

present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
if output_attentions:
    outputs = outputs + (attn_weights,)

第二次调用t5layercrossattention内容

刚开始调用的参数一样

batch_size,seq_length = hidden_states.shape[:2]
real_seq_length = seq_length

这里的batch_size = 1,seq_length = 1,real_seq_length = 1
接着调用

key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

获得参数

key_length = 11

唯一的区别就在于key_states和value_states的调用过程不一样

key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)

首先,这里传入的past_key_value[0]和past_key_value[1]为同一层次上一波的运行的结果
这里调用的past_key_value[0]相当于上一个位置同一层输出的key_states,past_key_value[1]相当于上一个位置同一层输出的value_states(比如这里是第二波6个encoder+3个decoder+第4个decoder的selflayerattention,那么前面就相当于第一波的6个encoder+3个decoder+第4个decoder的selflayerattention的内容)
接下来进入project函数之中

def project(hidden_states, proj_layer, key_value_states, past_key_value):
    """projects hidden states correctly to key/query states"""
    if key_value_states is None:
        # self-attn
        # (batch_size, n_heads, seq_length, dim_per_head)
        hidden_states = shape(proj_layer(hidden_states))
    elif past_key_value is None:
        # cross-attn
        # (batch_size, n_heads, seq_length, dim_per_head)
        hidden_states = shape(proj_layer(key_value_states))

    if past_key_value is not None:
        if key_value_states is None:
            # self-attn
            # (batch_size, n_heads, key_length, dim_per_head)
            hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
        else:
            # cross-attn
            hidden_states = past_key_value
    return hidden_states

直接运行最后一个else

hidden_states = past_key_value

获得hidden_states = torch.Size([1, 8, 11, 64])
总结一下project函数中的内容为,第一个if针对第一次的selflayerattention(包括encoder和decoder部分),else针对第一次的layercrossattention,第二个if针对第二次~第n次的selflayerattention,else针对第二次~第n次的layercrossattention
后续的操作都差不多

(1,8,1,64)*(1,8,64,11) = (1,8,1,11)
(1,8,1,11)*(1,8,11,64) = (1,8,1,64)