import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
# -------------------------
# Dummy Dataset
# -------------------------
# Toy corpus written as (text, label) pairs so every label sits right next to
# its text; zip(*...) then splits the pairs into two parallel lists.
# Label ids: 0 = Contract, 1 = Court Case
texts, labels = map(
    list,
    zip(
        ("This employment agreement is between company and employee.", 0),
        ("The court dismissed the appeal.", 1),
        ("Employee shall maintain confidentiality.", 0),
        ("The judge granted bail to the accused.", 1),
    ),
)
# -------------------------
# Tokenizer
# -------------------------
# Legal-domain BERT checkpoint; the same name is also used to load the encoder
# inside LegalClassifier.
MODEL_NAME = "nlpaueb/legal-bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=MODEL_NAME,
)
class LegalDataset(Dataset):
    """Map-style dataset that tokenizes one text per item on demand.

    Each item is a dict with ``input_ids`` and ``attention_mask`` tensors,
    padded/truncated to ``max_length``, plus an int64 ``label`` tensor.

    Args:
        texts: Sequence of raw input strings.
        labels: Sequence of integer class ids, aligned index-by-index with
            ``texts``.
        tokenizer: Optional tokenizer callable. When omitted, the module-level
            ``tokenizer`` is used, looked up at item-access time.
        max_length: Length every item is padded/truncated to.

    Raises:
        ValueError: If ``texts`` and ``labels`` have different lengths.
    """

    def __init__(self, texts, labels, tokenizer=None, max_length=128):
        if len(texts) != len(labels):
            raise ValueError(
                "texts and labels must have the same length "
                f"(got {len(texts)} and {len(labels)})"
            )
        self.texts = texts
        self.labels = labels
        self._tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Fall back to the module-level tokenizer so existing
        # LegalDataset(texts, labels) call sites keep working unchanged.
        tok = self._tokenizer if self._tokenizer is not None else tokenizer
        enc = tok(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {
            # return_tensors="pt" yields a leading batch dim of size 1. Drop
            # only that dim; a bare squeeze() would also collapse the sequence
            # dim when max_length == 1.
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            # CrossEntropyLoss expects int64 class indices.
            "label": torch.tensor(self.labels[idx], dtype=torch.long),
        }
# Wrap the toy corpus and batch it two examples at a time; shuffle=True makes
# the DataLoader reshuffle the order at every epoch.
dataset = LegalDataset(texts=texts, labels=labels)
loader = DataLoader(dataset, batch_size=2, shuffle=True)
# -------------------------
# Custom Model
# -------------------------
class LegalClassifier(nn.Module):
    """Pretrained transformer encoder followed by a small MLP classification head.

    The first-token ([CLS]) vector of the encoder's last hidden state goes
    through a feed-forward head that produces ``num_labels`` logits.

    Args:
        model_name: Checkpoint loaded with ``AutoModel.from_pretrained`` when
            no ``encoder`` is given. Defaults to the module-level MODEL_NAME,
            resolved at construction time.
        num_labels: Number of output classes (2 in this script:
            0 = Contract, 1 = Court Case).
        freeze_bert: If True, encoder parameters get ``requires_grad=False``,
            so only the head is trained.
        encoder: Optional pre-built encoder module to use instead of loading
            one. It must expose ``config.hidden_size`` and return an object
            with ``last_hidden_state`` when called.
    """

    def __init__(self, model_name=None, num_labels=2, freeze_bert=False, encoder=None):
        super().__init__()
        if encoder is None:
            name = model_name if model_name is not None else MODEL_NAME
            encoder = AutoModel.from_pretrained(name)
        self.bert = encoder
        # Read the width from the encoder config instead of hard-coding 768,
        # so large or distilled checkpoints also work.
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, num_labels),
        )
        if freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False

    def forward(self, input_ids, attention_mask):
        """Return classification logits of shape (batch, num_labels)."""
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # last_hidden_state is indexed as (batch, seq, hidden); position 0 is
        # the first ([CLS]) token.
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        return self.classifier(cls_embedding)
model = LegalClassifier()
# -------------------------
# Training
# -------------------------
# Hand the optimizer only trainable parameters; frozen (requires_grad=False)
# encoder weights would never receive gradients anyway.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=2e-5,
)
criterion = nn.CrossEntropyLoss()
EPOCHS = 3
model.train()
for epoch in range(EPOCHS):
    total_loss = 0.0
    num_batches = 0
    for batch in loader:
        optimizer.zero_grad()
        logits = model(batch["input_ids"], batch["attention_mask"])
        loss = criterion(logits, batch["label"])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # Report the mean per-batch loss; a raw sum grows with the number of
    # batches and is not comparable across dataset or batch sizes.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch + 1} Avg Loss = {avg_loss:.4f}")
# -------------------------
# Prediction
# -------------------------
model.eval()
test_text = "This agreement shall remain confidential."
# Tokenize with the same settings LegalDataset uses for training (truncate and
# pad to 128 tokens). Previously truncation fell back to the model's maximum
# length, so long inputs would be seen differently than during training.
enc = tokenizer(
    test_text,
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_length=128,
)
with torch.no_grad():
    logits = model(enc["input_ids"], enc["attention_mask"])
pred = torch.argmax(logits, dim=1)
# Label ids follow the training labels: 0 = Contract, 1 = Court Case.
id2label = {0: "Contract", 1: "Court Case"}
label_id = pred.item()
print("Prediction:", label_id, f"({id2label.get(label_id, 'unknown')})")
# To freeze the BERT encoder (so only the classifier head is trained), add
# this loop inside LegalClassifier.__init__, after self.bert is created:
#
#     for p in self.bert.parameters():
#         p.requires_grad = False