Skip to content

Latest commit

 

History

History

Folders and files

NameName
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 

The models were trained using LLaMA-Factory.

Please replace the modeling_qwen2.py file in the Transformers package (version 4.43.2), located at transformers/models/qwen2/modeling_qwen2.py, with the provided version.

More specifically, we insert the TKPO loss code into the original `forward` method of `class Qwen2ForCausalLM(Qwen2PreTrainedModel)`, as shown below:

# Loss computation excerpt: adds a contrastive "TKPO" term on top of the usual
# next-token cross-entropy. Names such as `logits`, `labels`, `input_ids`,
# `attention_mask`, `self`, `torch`, `copy` and `CrossEntropyLoss` come from the
# enclosing forward function / module and are not defined in this excerpt.
loss = None
# Hard-coded switch: as written, the plain-CE `else` branch below is never taken.
tkpo_loss = True
if labels is not None:
    if tkpo_loss:
        # tau divides the probabilities inside exp(); lambda_loss weights the
        # TKPO term when it is added to the cross-entropy loss.
        tau = 0.15
        lambda_loss = 0.5
        # Drop the last position of the logits and (below) the first label so
        # that position t is scored against the token at t + 1.
        shift_logits = logits[...,:-1,:].contiguous()
        shift_probs_pos = torch.softmax(shift_logits, dim=-1)
        pos_token = labels[..., 1:].contiguous()
        # Copy of the labels with the ignore value -100 replaced by token id 100
        # so it is a valid gather/scatter index; those positions are zeroed out
        # again through mask_label further down.
        pos_tokencp = copy.deepcopy(pos_token)
        pos_tokencp[pos_token==-100] = 100
        # Probability assigned to the label token at every position, then
        # transformed to exp(p / tau). The index on dim 2 means the tensors are
        # treated as 3-D here — assumes (batch, seq, vocab), TODO confirm.
        pos_prob = torch.gather(shift_probs_pos, 2, pos_tokencp.unsqueeze(-1)).squeeze(-1)
        pos_prob = torch.exp(pos_prob/tau)
        # Everything in this block runs without autograd recording, including
        # the final gather that produces neg_probs.
        with torch.no_grad():
            ## Style-keyword token-id pairs (translated from the original Chinese note):
            ## [104179, 117599]: "literary / poetic", [100267, 105310]: "factual / objective",
            ## [100700, 100011]: "detailed / comprehensive", [110485, 99936]: "concise / key points"
            ### language style
            # Build a counterpart prompt by swapping one style-keyword id pair
            # for the other, touching only the first 20 token positions.
            input_ids_anti = input_ids.clone()
            # The check looks for the id pair as a substring of the printed
            # token list of row 0 only.
            # NOTE(review): other rows of the batch are not inspected — confirm
            # batches are homogeneous (or of size 1).
            if "104179, 117599" in str(input_ids[0, :].clone().cpu().detach().numpy().tolist()):
                input_ids_anti[:, :20][input_ids_anti[:, :20]==104179] = 100267
                input_ids_anti[:, :20][input_ids_anti[:, :20]==117599] = 105310
            else:
                input_ids_anti[:, :20][input_ids_anti[:, :20]==100267] = 104179
                input_ids_anti[:, :20][input_ids_anti[:, :20]==105310] = 117599
            # The bare string literal below is an alternative ("level-of-detail")
            # swap that is disabled by being wrapped in a string.
            """
            ### level-of-detail
            input_ids_anti = input_ids.clone()
            if "100700, 100011" in str(input_ids[0, :].clone().cpu().detach().numpy().tolist()):
                input_ids_anti[:, :20][input_ids_anti[:, :20]==100700] = 110485
                input_ids_anti[:, :20][input_ids_anti[:, :20]==100011] = 99936
            elif "100267, 105310" in str(input_ids[0, :].clone().cpu().detach().numpy().tolist()):
                input_ids_anti[:, :20][input_ids_anti[:, :20]==100267] = 110485
                input_ids_anti[:, :20][input_ids_anti[:, :20]==105310] = 99936
            else:
                input_ids_anti[:, :20][input_ids_anti[:, :20]==110485] = 100267
                input_ids_anti[:, :20][input_ids_anti[:, :20]==99936] = 105310
            """ 
            # Second pass through self.model on the swapped prompt, reusing the
            # same mask / position / cache arguments as the enclosing call.
            outputs_anti = self.model(
                        input_ids=input_ids_anti.to(input_ids.dtype),
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        past_key_values=past_key_values,
                        inputs_embeds=inputs_embeds,
                        use_cache=use_cache,
                        output_attentions=output_attentions,
                        output_hidden_states=output_hidden_states,
                        return_dict=return_dict,
                        cache_position=cache_position,
                    )
            hidden_states_anit = outputs_anti[0]
            logits_anti = self.lm_head(hidden_states_anit)[...,:-1,:].contiguous()
            logits_neg = logits_anti.float()
            # Push the label token's logit to -1e6 so it is effectively never
            # drawn as a negative.
            logits_neg.scatter_(2, pos_tokencp.unsqueeze(-1), -1e6)
            # _, neg_token = torch.topk(logits_neg, 20, dim=-1)
            # neg_token = reject_sampling_tensor(torch.softmax(logits_neg, dim=-1), k=20)
            # Draw 20 distinct token ids per position from the swapped-prompt
            # distribution.
            # NOTE(review): squeeze(0) only removes the leading dim when it is 1,
            # and torch.multinomial is documented for 1-D/2-D input — this looks
            # like it requires batch size 1; confirm.
            neg_token = torch.multinomial(torch.softmax(logits_neg, dim=-1).squeeze(0), 20, replacement=False).unsqueeze(0)
            # Probabilities of the sampled negatives are read from the ORIGINAL
            # prompt's distribution (shift_probs_pos), not from the swapped one.
            neg_probs = torch.gather(shift_probs_pos, 2, neg_token)
        
        # Denominator: sum of exp(p_neg / tau) over the 20 negatives plus the
        # positive term; the ratio below is the positive's share of that sum.
        neg_pos_prob = torch.sum(torch.exp(neg_probs/tau),dim=-1) + pos_prob
        losses_tkpo = pos_prob/neg_pos_prob
        # 1 where a real label exists, 0 where the label is the ignore value -100.
        mask_label = torch.ones_like(pos_token)
        mask_label[pos_token==-100] = 0
        # Negative log of the masked mean ratio along the last dim.
        # NOTE(review): if every label in a row is -100 the divisor is 0 — confirm
        # such rows cannot occur.
        loss_tkpo = -torch.log(torch.sum((losses_tkpo * mask_label),dim=-1)/torch.sum(mask_label,dim=-1))


        # Cross-entropy over the flattened shifted logits / labels, then the
        # combined objective: CE + lambda_loss * TKPO.
        shift_logits_ce = shift_logits.view(-1, self.config.vocab_size)
        shift_labels_ce = pos_token.view(-1)
        loss_fct = CrossEntropyLoss()
        loss_ce = loss_fct(shift_logits_ce, shift_labels_ce)
        loss = loss_ce + lambda_loss*loss_tkpo.squeeze()
    else:
        # Plain cross-entropy path; the excerpt ends before `loss` is assigned.
        ## CE loss
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        loss_fct = CrossEntropyLoss()