在我们之前的文章中,我们介绍了研究人员提出了一种挑战transformer的新架构mamba。
他们的研究表明,Mamba是一种状态空间模型(SSM),在多种模式(如语言、音频和时间序列)中展现出了卓越的性能。为了证明这一点,研究人员使用Mamba-3B模型进行了语言建模实验。该模型超越了同等大小的Transformer模型,并且在预训练和下游评估期间,其表现与大小为其两倍的Transformer模型相当。
Mamba的独特之处在于其快速处理能力、选择性SSM层以及受FlashAttention启发的硬件友好设计。这些特点使Mamba超越了Transformer(Transformer没有传统的注意力和MLP块)。
许多人希望亲自测试Mamba的效果,因此本文整理了一个可以在Colab上完整运行的Mamba代码示例,并使用了Mamba官方的3B模型进行实际运行测试。
首先,我们需要安装依赖,这是官网推荐的:
!pip install causal-conv1d==1.0.0!pip install mamba-ssm==1.0.1
接下来,直接使用transformers库读取预训练的Mamba-3B模型:
import torchimport osfrom transformers import AutoTokenizerfrom mamba_ssm.models.mixer_seq_simple import MambaLMHeadModeltokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")model = MambaLMHeadModel.from_pretrained(os.path.expanduser("state-spaces/mamba-2.8b"), device="cuda", dtype=torch.bfloat16)
可以看到,3B的模型大小为11G。
面试猫
AI面试助手,在线面试神器,助你轻松拿Offer
39 查看详情

然后进行内容生成测试:
tokens = tokenizer("What is the meaning of life", return_tensors="pt")input_ids = tokens.input_ids.to(device="cuda")max_length = input_ids.shape[1] + 80fn = lambda: model.generate( input_ids=input_ids, max_length=max_length, cg=True, return_dict_in_generate=True, output_scores=True, enable_timing=False, temperature=0.1, top_k=10, top_p=0.1,)out = fn()print(tokenizer.decode(out[0][0]))
这里还有一个聊天示例:
import torchfrom transformers import AutoTokenizerfrom mamba_ssm.models.mixer_seq_simple import MambaLMHeadModeldevice = "cuda"tokenizer = AutoTokenizer.from_pretrained("havenhq/mamba-chat")tokenizer.eos_token = ""tokenizer.pad_token = tokenizer.eos_tokentokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_templatemodel = MambaLMHeadModel.from_pretrained("havenhq/mamba-chat", device="cuda", dtype=torch.float16)messages = []user_message = """What is the date for announcement On August 10 said that its arm JSW Neo Energy has agreed to buy a portfolio of 1753 mega watt renewable energy generation capacity from Mytrah Energy India Pvt Ltd for Rs 10,530 crore."""messages.append(dict(role="user", content=user_message))input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id)decoded = tokenizer.batch_decode(out)messages.append(dict(role="assistant", content=decoded[0].split("n")[-1]))print("Model:", decoded[0].split("n")[-1])
我已经将所有代码整理成Colab Notebook,有兴趣的可以直接使用:
https://www.php.cn/link/767593ee1911f484bc931f9a10f34b66
以上就是在Colab上测试Mamba的详细内容,更多请关注创想鸟其它相关文章!
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 chuangxiangniao@163.com 举报,一经查实,本站将立刻删除。
发布者:程序猿,转转请注明出处:https://www.chuangxiangniao.com/p/459195.html
微信扫一扫
支付宝扫一扫