你的位置:首页 > 信息动态 > 新闻中心
信息动态
联系我们

TorchScript的简介以及项目踩坑

2021/12/13 8:46:19

最近做的项目需要提升模型性能,缩短运行时间。发现TorchScript这么个玩意,之前一直没用过,所以就拿来试试。

简介

使用了JIT及时编译器,为PyTorch创建可序列化和可优化的模型。以额外的开发工作为代价换来一些模型运行性能。

TorchScript可以看成是Python的一个子集,只支持特定的操作和数据类型。普通的列表、字典等是不能够使用的。可以使用的大部分来自typing包(比如typing.List, typing.Dict),见下图。
在这里插入图片描述
里面说字典中键的数据类型必须是str、int或者float,对值没有要求。但是我在项目中使用的一个字典中值是深度学习的模型,也无法使用,不知道是怎么回事。

默认情况下,TorchScript函数的所有参数都被假定为张量Tensor。所以,TorchScript对变量有类型指定比较严格。部分深度学习模型本身已经有实现好的TorchScript版本,比如说BERT,调用时候加个torchscript参数就好.

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

但是RoBERTa就没有。我从网上复制了huggingface中RoBERTa的代码,然后做了type specifying。但是在初始化模型的时候仍然失败。报错如下。

Unsupported annotation typing.Union[~T, NoneType] could not be resolved because None could not be resolved.

目前还不知道是什么原因。

优势

  • 网上测评大概能缩短10%的运行时间(未亲自测试)

劣势

  • 算是比较新的库,很多功能没有实现或完善
  • 相关资料不多,为开发造成一定困难

实战

应用起来倒不是很费事。初始化的时候在需要的模型前面加上torch.jit.script字样就好。

 model = torch.jit.script(ResNet())

也可以把模型本身变成TorchScript版本(需要有@torch.jit.script注释)。这样初始化的时候就不需要再加torch.jit.script字样。

@torch.jit.script
class myModel:
  def __init__(self, x):
    self.x = x
    
  def forward(self):
  	return self.x * 10

 model = myModel()

需要注意的是myModel中不能将函数隐藏在另外的函数之内。下列情况不被允许。

@torch.jit.script
class myModel:
  def __init__(self, x):
    self.x = x
    
  def forward(self):
  	return adding3()
  	
  def adding3(self):
  	def adding2(self):
  		return 2 + self.x
  	return 1 + adding2(self)

如果myModel中有函数未被forward调用,那么需要在该函数前加上@torch.jit.export注释。例如,

@torch.jit.script
class myModel:
  def __init__(self, x):
    self.x = x
    
  def forward(self):
  	return x * x
  
  @torch.jit.export	
  def adding3(self):
  	return 3 + self.x

附送从Optional[int]转换成int的小技巧。

from typing import TypeVar
T = TypeVar('T')
def cast_away_optional(arg: Optional[T]) -> T:
    """
    Casting. 
    https://github.com/python/typing/issues/645
    """
    assert arg is not None
    return arg

Reference

  • https://pytorch.org/docs/stable/jit_language_reference.html
  • https://huggingface.co/transformers/v2.1.1/torchscript.html
  • https://github.com/python/typing/issues/645