GNNにおける学習時にエラーが起きた
ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'
code
問題が起きたのは上の記事での学習と評価...の部分
# 学習と評価
def train(model, loader, device, optimizer, epoch):
    """Train ``model`` for link prediction on ("blog", "tagged", "tag") edges.

    Args:
        model: Module called as ``model(batch)``; its output is treated as
            logits.
        loader: Iterable of mini-batches. Each batch must support
            ``.to(device)`` and expose
            ``batch["blog", "tagged", "tag"].edge_label``.
        device: Device every batch is moved to before the forward pass.
        optimizer: Optimizer that is stepped once per mini-batch.
        epoch: Exclusive upper bound of the epoch counter. Counting starts
            at 1, so ``epoch - 1`` passes over ``loader`` are made
            (``epoch=6`` trains for 5 epochs).
    """
    model.train()
    # A dedicated loop variable keeps the ``epoch`` argument from being
    # shadowed inside the loop.
    for current_epoch in range(1, epoch):
        total_loss = total_samples = 0
        for batch_data in tqdm(loader):
            optimizer.zero_grad()
            batch_data = batch_data.to(device)
            pred = model(batch_data)
            loss = F.binary_cross_entropy_with_logits(
                pred, batch_data["blog", "tagged", "tag"].edge_label
            )
            loss.backward()
            optimizer.step()
            # Weight the batch loss by the number of predictions so the
            # epoch value is a per-sample average even when batch sizes
            # differ.
            total_loss += float(loss) * pred.numel()
            total_samples += pred.numel()
        # An empty loader yields no samples; report NaN instead of raising
        # ZeroDivisionError.
        mean_loss = total_loss / total_samples if total_samples else float("nan")
        print(f"Epoch: {current_epoch:04d}, Loss: {mean_loss:.4f}")
def validation(model, loader, device, optimizer):
    """Score ``model`` on ``loader`` and compute ROC-AUC.

    Args:
        model: Module called as ``model(batch)``; switched to eval mode here.
        loader: Iterable of mini-batches exposing
            ``batch["blog", "tagged", "tag"].edge_label``.
        device: Device every batch is moved to before the forward pass.
        optimizer: Not used by this function; kept so existing call sites
            keep working.

    Returns:
        Tuple ``(auc, y_pred, y_true)``: the ROC-AUC score plus the
        concatenated predictions and labels as NumPy arrays.
    """
    model.eval()
    scores, labels = [], []
    for batch in tqdm(loader):
        # Gradients are not needed for evaluation.
        with torch.no_grad():
            batch = batch.to(device)
            scores.append(model(batch))
            labels.append(batch["blog", "tagged", "tag"].edge_label)
    y_pred = torch.cat(scores, dim=0).cpu().numpy()
    y_true = torch.cat(labels, dim=0).cpu().numpy()
    return roc_auc_score(y_true, y_pred), y_pred, y_true
# Parameter setup: pick the device, then build the model and its optimizer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Move the model to the selected device (GPU when available, otherwise CPU).
model = model.to(device)
Error
ImportError Traceback (most recent call last)
Cell In[41], line 43
40 model = model.to(device)
42 # 学習・評価
---> 43 train(model, train_loader, device, optimizer, 6)
44 auc, y_pred, y_true = validation(model, val_loader, device, optimizer)
46 # 精度確認(ROC-AUC曲線)
Cell In[41], line 6, in train(model, loader, device, optimizer, epoch)
4 for epoch in range(1, epoch):
5 total_loss = total_samples = 0
----> 6 for batch_data in tqdm(loader):
7 optimizer.zero_grad()
8 batch_data = batch_data.to(device)
File /opt/conda/lib/python3.10/site-packages/tqdm/std.py:1182, in tqdm.__iter__(self)
1179 time = self._time
1181 try:
-> 1182 for obj in iterable:
1183 yield obj
1184 # Update and possibly print the progressbar.
1185 # Note: does not call self.update(1) for speed optimisation.
File /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:631, in _BaseDataLoaderIter.__next__(self)
628 if self._sampler_iter is None:
629 # TODO(https://github.com/pytorch/pytorch/issues/76750)
630 self._reset() # type: ignore[call-arg]
--> 631 data = self._next_data()
632 self._num_yielded += 1
633 if self._dataset_kind == _DatasetKind.Iterable and \
634 self._IterableDataset_len_called is not None and \
635 self._num_yielded > self._IterableDataset_len_called:
File /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:675, in _SingleProcessDataLoaderIter._next_data(self)
673 def _next_data(self):
674 index = self._next_index() # may raise StopIteration
--> 675 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
676 if self._pin_memory:
677 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
File /opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:54, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
52 else:
53 data = self.dataset[possibly_batched_index]
---> 54 return self.collate_fn(data)
File /opt/conda/lib/python3.10/site-packages/torch_geometric/loader/link_loader.py:211, in LinkLoader.collate_fn(self, index)
208 r"""Samples a subgraph from a batch of input edges."""
209 input_data: EdgeSamplerInput = self.input_data[index]
--> 211 out = self.link_sampler.sample_from_edges(
212 input_data, neg_sampling=self.neg_sampling)
214 if self.filter_per_worker: # Execute `filter_fn` in the worker process
215 out = self.filter_fn(out)
File /opt/conda/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py:334, in NeighborSampler.sample_from_edges(self, inputs, neg_sampling)
329 def sample_from_edges(
330 self,
331 inputs: EdgeSamplerInput,
332 neg_sampling: Optional[NegativeSampling] = None,
333 ) -> Union[SamplerOutput, HeteroSamplerOutput]:
--> 334 out = edge_sample(inputs, self._sample, self.num_nodes, self.disjoint,
335 self.node_time, neg_sampling)
336 if self.subgraph_type == SubgraphType.bidirectional:
337 out = out.to_bidirectional()
File /opt/conda/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py:666, in edge_sample(inputs, sample_fn, num_nodes, disjoint, node_time, neg_sampling)
661 if edge_label_time is not None: # Always disjoint.
662 seed_time_dict = {
663 input_type[0]: torch.cat([src_time, dst_time], dim=0),
664 }
--> 666 out = sample_fn(seed_dict, seed_time_dict)
668 # Enhance `out` by label information ##################################
669 if disjoint:
File /opt/conda/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py:431, in NeighborSampler._sample(self, seed, seed_time, **kwargs)
428 num_sampled_nodes = num_sampled_edges = None
430 else:
--> 431 raise ImportError(f"'{self.__class__.__name__}' requires "
432 f"either 'pyg-lib' or 'torch-sparse'")
434 if num_sampled_edges is not None:
435 num_sampled_edges = remap_keys(
436 num_sampled_edges,
437 self.to_edge_type,
438 )
ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'
原因と解決策
原因
単純にpytorch-geometricのオプションライブラリ群(pyg-lib, torch-sparse など)が入っていない場合は,以下のようにインストールする
# Install required packages.
import os
os.environ['TORCH'] = torch.__version__
!pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install pyg-lib -f https://data.pyg.org/whl/nightly/torch-${TORCH}.html
pytorch-geometric,pyg-libなどを問題なくimportできるのにエラーが出る場合は,インストールする順番がおかしい可能性がある
check
# Check whether the optional extension is visible to PyG.
import torch_geometric
# Prints the flag stored in torch_geometric.typing; presumably True only when
# pyg-lib was detected at import time (confirm for the PyG version in use).
print(torch_geometric.typing.WITH_PYG_LIB)
# Importing the package directly fails here if pyg-lib itself cannot be loaded.
import pyg_lib
torch -> pytorch-geometric -> pyg_lib torch_scatter torch_sparse torch_cluster などのオプションライブラリ
の順番でインストールすると,参照するpytorch-geometricのversionが異なることでエラーになるらしい
解決策
torch -> pyg_lib torch_scatter torch_sparse torch_cluster などのオプションライブラリ -> pytorch-geometric
の順番でインストールする
!pip install torch==2.2.0
# Install required packages.
import os
os.environ['TORCH'] = torch.__version__
!pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install pyg-lib -f https://data.pyg.org/whl/nightly/torch-${TORCH}.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git
参考