ImportError: 'NeighborSampler'... in pytorch-geometric

An error occurred while training a GNN:
ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'

Code

The problem occurred in the "training and evaluation..." part of the code from the article above.

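# Training and evaluation
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from tqdm import tqdm


def train(model, loader, device, optimizer, num_epochs):
    model.train()
    # Separate loop variable so the epoch count is not shadowed; runs epochs 1..num_epochs
    for epoch in range(1, num_epochs + 1):
        total_loss = total_samples = 0
        # The ImportError below is raised here, when the loader samples a batch
        for batch_data in tqdm(loader):
            optimizer.zero_grad()
            batch_data = batch_data.to(device)
            pred = model(batch_data)
            loss = F.binary_cross_entropy_with_logits(
                pred, batch_data["blog", "tagged", "tag"].edge_label
            )
            loss.backward()
            optimizer.step()
            total_loss += float(loss) * pred.numel()
            total_samples += pred.numel()
        print(f"Epoch: {epoch:04d}, Loss: {total_loss / total_samples:.4f}")


def validation(model, loader, device, optimizer=None):
    # optimizer is not used; the parameter is kept so existing calls keep working
    y_preds = []
    y_trues = []
    model.eval()
    with torch.no_grad():
        for batch_data in tqdm(loader):
            batch_data = batch_data.to(device)
            pred = model(batch_data)
            y_preds.append(pred)
            y_trues.append(batch_data["blog", "tagged", "tag"].edge_label)

    y_pred = torch.cat(y_preds, dim=0).cpu().numpy()
    y_true = torch.cat(y_trues, dim=0).cpu().numpy()
    # ROC-AUC depends only on the ranking of the scores, so raw logits are fine here
    auc = roc_auc_score(y_true, y_pred)
    return auc, y_pred, y_true


# Parameter setup (Model, train_loader and val_loader come from the article above)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model(hidden_channels=64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training and evaluation (the calls shown in the traceback)
train(model, train_loader, device, optimizer, 6)
auc, y_pred, y_true = validation(model, val_loader, device, optimizer)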

Error

ImportError                               Traceback (most recent call last)
Cell In[41], line 43
     40 model = model.to(device)
     42 # Training and evaluation
---> 43 train(model, train_loader, device, optimizer, 6)
     44 auc, y_pred, y_true = validation(model, val_loader, device, optimizer)
     46 # Check accuracy (ROC-AUC curve)

Cell In[41], line 6, in train(model, loader, device, optimizer, epoch)
      4 for epoch in range(1, epoch):
      5     total_loss = total_samples = 0
----> 6     for batch_data in tqdm(loader):
      7         optimizer.zero_grad()
      8         batch_data = batch_data.to(device)

File /opt/conda/lib/python3.10/site-packages/tqdm/std.py:1182, in tqdm.__iter__(self)
   1179 time = self._time
   1181 try:
-> 1182     for obj in iterable:
   1183         yield obj
   1184         # Update and possibly print the progressbar.
   1185         # Note: does not call self.update(1) for speed optimisation.

File /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:631, in _BaseDataLoaderIter.__next__(self)
    628 if self._sampler_iter is None:
    629     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    630     self._reset()  # type: ignore[call-arg]
--> 631 data = self._next_data()
    632 self._num_yielded += 1
    633 if self._dataset_kind == _DatasetKind.Iterable and \
    634         self._IterableDataset_len_called is not None and \
    635         self._num_yielded > self._IterableDataset_len_called:

File /opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py:675, in _SingleProcessDataLoaderIter._next_data(self)
    673 def _next_data(self):
    674     index = self._next_index()  # may raise StopIteration
--> 675     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    676     if self._pin_memory:
    677         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File /opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:54, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     52 else:
     53     data = self.dataset[possibly_batched_index]
---> 54 return self.collate_fn(data)

File /opt/conda/lib/python3.10/site-packages/torch_geometric/loader/link_loader.py:211, in LinkLoader.collate_fn(self, index)
    208 r"""Samples a subgraph from a batch of input edges."""
    209 input_data: EdgeSamplerInput = self.input_data[index]
--> 211 out = self.link_sampler.sample_from_edges(
    212     input_data, neg_sampling=self.neg_sampling)
    214 if self.filter_per_worker:  # Execute `filter_fn` in the worker process
    215     out = self.filter_fn(out)

File /opt/conda/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py:334, in NeighborSampler.sample_from_edges(self, inputs, neg_sampling)
    329 def sample_from_edges(
    330     self,
    331     inputs: EdgeSamplerInput,
    332     neg_sampling: Optional[NegativeSampling] = None,
    333 ) -> Union[SamplerOutput, HeteroSamplerOutput]:
--> 334     out = edge_sample(inputs, self._sample, self.num_nodes, self.disjoint,
    335                       self.node_time, neg_sampling)
    336     if self.subgraph_type == SubgraphType.bidirectional:
    337         out = out.to_bidirectional()

File /opt/conda/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py:666, in edge_sample(inputs, sample_fn, num_nodes, disjoint, node_time, neg_sampling)
    661     if edge_label_time is not None:  # Always disjoint.
    662         seed_time_dict = {
    663             input_type[0]: torch.cat([src_time, dst_time], dim=0),
    664         }
--> 666 out = sample_fn(seed_dict, seed_time_dict)
    668 # Enhance `out` by label information ##################################
    669 if disjoint:

File /opt/conda/lib/python3.10/site-packages/torch_geometric/sampler/neighbor_sampler.py:431, in NeighborSampler._sample(self, seed, seed_time, **kwargs)
    428     num_sampled_nodes = num_sampled_edges = None
    430 else:
--> 431     raise ImportError(f"'{self.__class__.__name__}' requires "
    432                       f"either 'pyg-lib' or 'torch-sparse'")
    434 if num_sampled_edges is not None:
    435     num_sampled_edges = remap_keys(
    436         num_sampled_edges,
    437         self.to_edge_type,
    438     )

ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'
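The last frame shows that `NeighborSampler._sample` reaches an `else:` branch and raises when neither sampling backend is usable. The sketch below is a simplified model of that decision, not PyG's source. The helper `select_sampler_backend` is hypothetical, and it assumes pyg-lib is tried before torch-sparse. In PyG, the two inputs correspond to the flags `torch_geometric.typing.WITH_PYG_LIB` and `torch_geometric.typing.WITH_TORCH_SPARSE`. These flags are assumed here to reflect whether PyG managed to load each library, not merely whether the library is present on disk.

```python
# Simplified, hypothetical model of the branch that raised the error above.
# Not PyG's source: the helper name and the pyg-lib-first order are assumptions.


def select_sampler_backend(with_pyg_lib: bool, with_torch_sparse: bool) -> str:
    """Pick a neighbor-sampling backend from two availability flags."""
    if with_pyg_lib:
        return "pyg-lib"
    if with_torch_sparse:
        return "torch-sparse"
    # Neither backend is usable: this mirrors the `else: raise ImportError(...)` frame.
    raise ImportError("'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'")


if __name__ == "__main__":
    for flags in [(True, False), (False, True), (False, False)]:
        try:
            print(flags, "->", select_sampler_backend(*flags))
        except ImportError as exc:
            print(flags, "-> ImportError:", exc)
```

Under this model, installing a library only fixes the error if PyG actually detects it afterwards.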

Cause and solution

Cause

If the optional pytorch-geometric libraries are simply not installed, install them:

# Install required packages (Jupyter/Colab cell; "!" runs a shell command).
import os
import torch

# ${TORCH} in the commands below is expanded by the shell from this environment variable
os.environ['TORCH'] = torch.__version__

!pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install pyg-lib -f https://data.pyg.org/whl/nightly/torch-${TORCH}.html
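The `-f` page in these commands depends on the exact `torch.__version__` string. Any local suffix such as `+cu121` therefore becomes part of the URL. The small sketch below shows the URL the variable expands into. The URL patterns are copied from the commands above. The helper name `pyg_find_links` is hypothetical.

```python
# Sketch of how the -f (find-links) page URL is formed from the torch version string.
# URL patterns are copied from the pip commands above; the helper name is hypothetical.


def pyg_find_links(torch_version: str, nightly: bool = False) -> str:
    """Return the data.pyg.org wheel page URL that ${TORCH} expands into."""
    base = "https://data.pyg.org/whl/nightly" if nightly else "https://data.pyg.org/whl"
    return f"{base}/torch-{torch_version}.html"


if __name__ == "__main__":
    # Example value only; in the notebook this is torch.__version__.
    print(pyg_find_links("2.2.0"))
    print(pyg_find_links("2.2.0", nightly=True))
```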

If pytorch-geometric, pyg-lib, etc. can all be imported without problems, the installation order is the problem.

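Check

import torch_geometric.typing

print(torch_geometric.typing.WITH_PYG_LIB)       # whether PyG detected pyg-lib
print(torch_geometric.typing.WITH_TORCH_SPARSE)  # whether PyG detected torch-sparse
import pyg_lib                                   # whether pyg-lib itself can be imported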
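The check can be extended into a small report that separates a missing library from one that is installed but not picked up by PyG. This is a sketch under assumptions. `diagnose` is a hypothetical helper. It relies only on `importlib.util.find_spec`, which tells whether a package is installed, and on the `WITH_PYG_LIB` / `WITH_TORCH_SPARSE` flags in `torch_geometric.typing`. If a package is reported as installed while its flag is `False`, the cause is likely the install-order problem described next.

```python
# Diagnostic sketch: "is the package installed?" vs. "did PyG detect it?".
# diagnose() is a hypothetical helper; only standard-library calls plus the
# WITH_PYG_LIB / WITH_TORCH_SPARSE flags from torch_geometric.typing are used.
import importlib
import importlib.util

PACKAGES = ("torch", "torch_geometric", "pyg_lib", "torch_sparse")


def diagnose() -> dict:
    report = {}
    for name in PACKAGES:
        # find_spec on a top-level name locates the package without importing it
        report[f"installed:{name}"] = importlib.util.find_spec(name) is not None
    try:
        pyg_typing = importlib.import_module("torch_geometric.typing")
        report["WITH_PYG_LIB"] = bool(getattr(pyg_typing, "WITH_PYG_LIB", False))
        report["WITH_TORCH_SPARSE"] = bool(getattr(pyg_typing, "WITH_TORCH_SPARSE", False))
    except (ImportError, OSError):
        # PyG itself (or torch underneath it) could not be imported
        report["WITH_PYG_LIB"] = False
        report["WITH_TORCH_SPARSE"] = False
    return report


if __name__ == "__main__":
    for key, value in diagnose().items():
        print(f"{key:28s} {value}")
```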

Installing in the order torch -> pytorch-geometric -> optional libraries (pyg_lib, torch_scatter, torch_sparse, torch_cluster, etc.) apparently causes the error, because a different pytorch-geometric version ends up being referenced.

Solution

Install in the order torch -> optional libraries (pyg_lib, torch_scatter, torch_sparse, torch_cluster, etc.) -> pytorch-geometric:

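# 1. torch
#    If another torch version was already imported in this session, restart the kernel
#    after this step; otherwise torch.__version__ below still reports the old version.
!pip install torch==2.2.0

# 2. Optional libraries, built for the installed torch version
import os
import torch
os.environ['TORCH'] = torch.__version__

!pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install pyg-lib -f https://data.pyg.org/whl/nightly/torch-${TORCH}.html

# 3. pytorch-geometric last
!pip install git+https://github.com/pyg-team/pytorch_geometric.git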
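Outside Jupyter, the `!pip` syntax is not available, but the same order can be driven from plain Python. The sketch below assumes only that pip can be run as `python -m pip`. The function names are hypothetical. The wheel tag is supplied by hand because it should match the `torch.__version__` of the build that actually gets installed. By default the script only prints the commands (`dry_run=True`).

```python
# Sketch: run the installs in the recommended order
# (torch -> optional libraries -> pytorch-geometric) without Jupyter's "!pip".
# Function names are hypothetical; wheel_tag must match torch.__version__ of the
# build that pip actually installs (supplied by hand here, not detected).
import subprocess
import sys

OPTIONAL_LIBS = ("torch-scatter", "torch-sparse")  # same packages as the commands above
PYG_SOURCE = "git+https://github.com/pyg-team/pytorch_geometric.git"


def ordered_install_commands(torch_pin: str, wheel_tag: str) -> list:
    """Return pip argv lists in the order torch -> optional libraries -> PyG."""
    pip = [sys.executable, "-m", "pip", "install"]
    commands = [pip + [f"torch=={torch_pin}"]]
    for lib in OPTIONAL_LIBS:
        commands.append(pip + [lib, "-f", f"https://data.pyg.org/whl/torch-{wheel_tag}.html"])
    commands.append(
        pip + ["pyg-lib", "-f", f"https://data.pyg.org/whl/nightly/torch-{wheel_tag}.html"]
    )
    # pytorch-geometric goes last, after the optional libraries are in place
    commands.append(pip + [PYG_SOURCE])
    return commands


def run_all(commands: list, dry_run: bool = True) -> None:
    """Print each command; execute it (stopping on the first failure) unless dry_run."""
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Example values only; check torch.__version__ for the real wheel tag.
    run_all(ordered_install_commands("2.2.0", "2.2.0+cu121"), dry_run=True)
```

This process never imports torch, because each install runs in its own subprocess. That avoids the stale `torch.__version__` issue noted in the notebook version.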
