Memory Leak Issue with ForwardIs Method in gotch Library #133

Open

yinziyang opened this issue Aug 7, 2024 · 4 comments
Description:

When using the gotch library with a JIT-compiled BERT model, calling the ForwardIs method repeatedly causes a memory leak. The memory usage continuously increases with each call to ForwardIs, which may lead to system instability after prolonged operation.

Steps to Reproduce:

  1. Train and save a BERT model using the following Python code:

    import torch
    from transformers import BertTokenizer, BertModel, BertForSequenceClassification
    import torch.nn as nn
    
    class ScriptableBertForSequenceClassification(BertForSequenceClassification):
        def __init__(self, config):
            super().__init__(config)
            self.bert = BertModel(config)
    
        def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
            if input_ids is not None:
                input_shape = input_ids.size()
            else:
                input_shape = inputs_embeds.size()[:-1]
    
            device = input_ids.device if input_ids is not None else inputs_embeds.device
    
            if attention_mask is None:
                attention_mask = torch.ones(input_shape, device=device)
            if token_type_ids is None:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
    
            extended_attention_mask = attention_mask[:, None, None, :]
    
            for param in self.bert.parameters():
                if param is not None:
                    extended_attention_mask = extended_attention_mask.to(dtype=param.dtype)
                    break
    
            extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
    
            embedding_output = self.bert.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
            )
            encoder_outputs = self.bert.encoder(
                embedding_output,
                attention_mask=extended_attention_mask,
                head_mask=head_mask,
            )
            sequence_output = encoder_outputs[0]
            pooled_output = self.bert.pooler(sequence_output)
    
            logits = self.classifier(pooled_output)
    
            loss = None
            if labels is not None:
                if self.num_labels == 1:
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
    
            return (logits, pooled_output, sequence_output) if loss is None else (loss, logits, pooled_output, sequence_output)
    
    model = ScriptableBertForSequenceClassification.from_pretrained('bert-base-multilingual-cased', num_labels=3)
    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
    input_text = "Hello, this is a test."
    inputs = tokenizer(input_text, return_tensors='pt')
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    token_type_ids = inputs["token_type_ids"]
    
    jit_model = torch.jit.trace(model, (input_ids, attention_mask, token_type_ids))
    torch.jit.save(jit_model, "model.pt")
  2. Load and call the model in Go using the gotch library:

    package main
    
    import (
    	"log"
    
    	"github.com/sugarme/gotch/ts"
    )
    
    func main() {
    	modelFile := "model.pt"
    	model, err := ts.ModuleLoad(modelFile)
    	if err != nil {
    		panic(err)
    	}
    
    	var inputIds = []int32{
    		101, 12865, 11639, 56011, 10908, 10473, 47798, 11424, 83438, 13663, 80017, 74661, 47464, 79326, 10271, 10114, 17734, 3378, 7104, 121, 2075, 2102, 7323, 2534, 3642, 8831, 4151, 7069, 3661, 5605, 3197, 3459, 29653, 6088, 2188, 4380, 2072, 4780, 2435, 7498, 5396, 5718, 73784, 5611, 2452, 2763, 2084, 2090, 3001, 8192, 3701, 2735, 4009, 7740, 5755, 2568, 4792, 73784, 3592, 6336, 6779, 3775, 3378, 2448, 7300, 7321, 7356, 3197, 3408, 4284, 3792, 2275, 5605, 7323, 2534, 2730, 7323, 7315, 5142, 2534, 121, 2075, 2102, 7323, 2534, 8332, 3626, 2080, 8098, 5718, 3661, 5605, 4163, 6748, 10900, 68897, 7700, 2102, 3661, 5605, 5769, 5718, 3199, 3240, 5605, 2146, 3661, 5605, 4982, 5619, 2080, 2688, 8422, 5618, 8335, 5061, 4409, 4252, 2259, 2299, 4142, 5484, 4941, 2105, 3419, 3191, 4577, 2773, 2149, 7838, 3031, 5484, 7700, 2102, 73784, 4462, 2731, 2206, 2756, 2204, 7333, 7333, 5765, 5769, 5718, 4380, 8332, 3626, 5718, 2080, 8098, 3731, 5293, 4163, 6748, 4163, 6748, 4163, 6748, 4333, 2597, 6546, 5396, 4476, 2762, 121, 2316, 3848, 2286, 3661, 2890, 3197, 3459, 121, 11517, 11274, 2465, 3410, 3824, 25986, 10929, 3978, 6457, 4449, 8595, 5396, 4476, 5718, 3378, 2678, 4004, 4482, 7349, 7168, 6088, 2251, 2104, 4313, 3642, 8831, 2534, 6309, 8215, 121, 3507, 2204, 2078, 2211, 4580, 2435, 7498, 2435, 7478, 3173, 4368, 4476, 6397, 6036, 2184, 4181, 2286, 2079, 7651, 4012, 6098, 73784, 2468, 3410, 5760, 2457, 2149, 8417, 5611, 3701, 2735, 6457, 7475, 7478, 2081, 4476, 6397, 6036, 4346, 2457, 102,
    	}
    	var attentionMask = []int32{
    		1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    	}
    	var tokenTypeIds = []int32{
    		0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    	}
    
    	l := int64(len(inputIds))
    
    	inputParams := make([]*ts.IValue, 3)
    
    	s1 := ts.MustOfSlice(inputIds)
    	ts1 := s1.MustView([]int64{1, l}, true)
    	defer s1.Drop()
    	log.Printf("%+v", ts1)
    	defer ts1.Drop()
    	inputParams[0] = ts.NewIValue(ts1)
    
    	s2 := ts.MustOfSlice(attentionMask)
    	ts2 := s2.MustView([]int64{1, l}, true)
    	defer s2.Drop()
    	log.Printf("%+v", ts2)
    	defer ts2.Drop()
    	inputParams[1] = ts.NewIValue(ts2)
    
    	s3 := ts.MustOfSlice(tokenTypeIds)
    	ts3 := s3.MustView([]int64{1, l}, true)
    	defer s3.Drop()
    	log.Printf("%+v", ts3)
    	defer ts3.Drop()
    	inputParams[2] = ts.NewIValue(ts3)
    
    	for {
    		ivs, err := model.ForwardIs(inputParams)   // Calling this line multiple times will lead to increasing memory usage
    		if err != nil {
    			panic(err)
    		}
    		xs := ivs.Value().([]*ts.Tensor)
    		for _, x := range xs {
    			log.Println(x)
    			x.MustDrop()
    		}
    	}
    }
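
As a rough way to confirm that the growth is outside the Go heap, one can compare the process RSS with Go's own heap statistics while the loop runs. The sketch below is Linux-specific and uses only the standard library (printRSS is a hypothetical helper, not part of gotch; it additionally needs os, runtime, and strings imported):

    // Hypothetical helper: compares process RSS (from /proc/self/status,
    // Linux-only) with the Go heap size. If RSS keeps growing while the Go
    // heap stays flat, the growing memory lives on the C (libtorch) side.
    func printRSS() {
    	data, err := os.ReadFile("/proc/self/status")
    	if err != nil {
    		return
    	}
    	for _, line := range strings.Split(string(data), "\n") {
    		if strings.HasPrefix(line, "VmRSS:") {
    			var m runtime.MemStats
    			runtime.ReadMemStats(&m)
    			log.Printf("%s | Go heap: %d KB", strings.TrimSpace(line), m.HeapAlloc/1024)
    		}
    	}
    }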

Expected Behavior:

Memory usage should remain stable regardless of the number of ForwardIs calls.

Actual Behavior:

Memory usage increases with each call to ForwardIs, so the process's footprint grows without bound over prolonged operation.

Additional Information:

  • Python Version: 3.12.4
  • PyTorch Version: 2.3.1
  • Gotch Version: 0.9.1
  • Transformers Version: 4.43.3
  • Operating System: Ubuntu 22.04

gmohmad commented Dec 10, 2024

Hey! Have you found a solution to this problem? I'm facing the same issue: the memory consumption skyrockets under high load and doesn't get freed.

sugarme (Owner) commented Dec 13, 2024

@yinziyang, @gmohmad,

You may need to wrap your code inside NoGrad(). Please see the code in the example folder.
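
For concreteness, applied to the reproduction above the suggestion would look roughly like this (a sketch only; model and inputParams are the variables from the repro, and ts.NoGrad runs the closure with gradient tracking disabled, so autograd state cannot accumulate across calls):

// Sketch: the inference loop from the report, wrapped in NoGrad as
// suggested. Returned tensors are still dropped explicitly after use.
ts.NoGrad(func() {
	ivs, err := model.ForwardIs(inputParams)
	if err != nil {
		panic(err)
	}
	for _, x := range ivs.Value().([]*ts.Tensor) {
		x.MustDrop()
	}
})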


gmohmad commented Dec 18, 2024

@sugarme
I've tried wrapping my code in NoGrad too, but the library still consumes a lot of memory and doesn't free it immediately. I tried adding debug logs with gotch.PrintMemStats() and ts.CheckCMemLeak, but they show that everything is fine: all the memory allocated by the C code is cleaned up. I'm using this library in a high-load server environment, and the memory eventually stacks up until the process gets killed by the OOM killer. But maybe I'm doing something wrong; here's my code:

func (b *jitModule) forward(t1, t2 []int64, topk int64) ([]int64, []float32, error) {
	var scores []float32
	var names []int64
	var retErr error

	ts.NoGrad(func() {
		inputTensor1 := ts.TensorFrom(t1)
		inputTensor2 := ts.TensorFrom(t2)
		defer inputTensor1.MustDrop()
		defer inputTensor2.MustDrop()

		outputTs, err := b.Model.ForwardTs([]*ts.Tensor{inputTensor1, inputTensor2})
		if err != nil {
			retErr = err
			return
		}
		defer outputTs.MustDrop()

		outTs1, outTs2, err := outputTs.TopK(topk, -1, true, false)
		if err != nil {
			log.Println("got invalid output from model")
			retErr = err
			return
		}
		defer outTs1.MustDrop()
		defer outTs2.MustDrop()

		scores = outTs1.Vals().([]float32)
		names = outTs2.Vals().([]int64)
	})

	return names, scores, retErr
}
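
A rough sketch of the kind of instrumentation described above, for anyone trying to reproduce the measurements (b, t1, and t2 are the values from the snippet; gotch.PrintMemStats and ts.CheckCMemLeak are used as mentioned in the comment, with their exact signatures assumed; the topk value is arbitrary):

// Drive the forward method in a loop and periodically dump the C-memory
// bookkeeping, to see whether C-side allocations are in fact balanced.
for i := 0; i < 10000; i++ {
	if _, _, err := b.forward(t1, t2, 5); err != nil {
		log.Fatal(err)
	}
	if i%1000 == 0 {
		gotch.PrintMemStats()
		log.Println(ts.CheckCMemLeak())
	}
}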


gmohmad commented Dec 23, 2024

Update: it looks like the memory is actually being freed, just not released back to the OS.
When I first deploy the service (k8s) and do load testing, the memory goes up drastically and then doesn't come down, but when I load test it again, the memory only goes up a little (compared to the first time) and then settles back to around where it was before the second load.
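
If the retention is on the allocator side, two generic knobs sometimes help; neither is gotch-specific, so treat this as a speculative sketch (needs runtime/debug and time imported). Go's runtime can be asked to hand freed pages back to the OS, and glibc's malloc, which serves libtorch's C allocations, can be told to cache less by setting MALLOC_ARENA_MAX (e.g. to 2) in the environment:

// Speculative mitigation, not gotch-specific: periodically return freed
// Go-heap pages to the OS. Note this does not touch tensor buffers, which
// libtorch allocates on the C heap; there, glibc's per-arena page caching
// is the usual suspect (tunable via the MALLOC_ARENA_MAX env variable).
go func() {
	for range time.Tick(30 * time.Second) {
		debug.FreeOSMemory()
	}
}()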
