DevGang
Авторизоваться

Функции split() и vsplit() в PyTorch

Функция split() может разделить 1D или более D тензор на 1 или несколько тензоров, как показано ниже. Установив размерность для второго аргумента, можно выбрать позицию разделения тензора:

import torch

my_tensor = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])

torch.split(my_tensor, 1)
my_tensor.split(1)
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

torch.split(my_tensor, 2)
my_tensor.split(2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

torch.split(my_tensor, 3)
my_tensor.split(3)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)

torch.split(my_tensor, (0, 3))
my_tensor.split((0, 3))
# (tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.split(my_tensor, (1, 2))
my_tensor.split((1, 2))
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.split(my_tensor, (2, 1))
my_tensor.split((2, 1))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

torch.split(my_tensor, (3, 0))
my_tensor.split((3, 0))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64))

torch.split(my_tensor, (1, 1, 1))
my_tensor.split((1, 1, 1))
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

Функция vsplit() может вертикально разбить 2D или более D тензор на 1 или несколько тензоров, как показано ниже. Установив размерность для второго аргумента, можно выбрать позицию разделения тензора:

import torch

my_tensor = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])

torch.vsplit(my_tensor, 1)
my_tensor.vsplit(1)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)

torch.vsplit(my_tensor, 3)
my_tensor.vsplit(3)
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

torch.vsplit(my_tensor, (0, 0))
my_tensor.vsplit((0, 0))
# (tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.vsplit(my_tensor, (0, 1))
my_tensor.vsplit((0, 1))
# (tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.vsplit(my_tensor, (0, 2))
my_tensor.vsplit((0, 2))
# (tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

torch.vsplit(my_tensor, (0, 3))
my_tensor.vsplit((0, 3))
# (tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64))

torch.vsplit(my_tensor, (1, 0))
my_tensor.vsplit((1, 0))
# (tensor([[0, 1, 2, 3]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.vsplit(my_tensor, (1, 1))
my_tensor.vsplit((1, 1))
# (tensor([[0, 1, 2, 3]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.vsplit(my_tensor, (1, 2))
my_tensor.vsplit((1, 2))
# (tensor([[0, 1, 2, 3]]), 
#  tensor([[4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

torch.vsplit(my_tensor, (1, 3))
my_tensor.vsplit((1, 3))
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64))

torch.vsplit(my_tensor, (2, 0))
my_tensor.vsplit((2, 0))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.vsplit(my_tensor, (2, 1))
my_tensor.vsplit((2, 1))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.vsplit(my_tensor, (2, 2))
my_tensor.vsplit((2, 2))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[8, 9, 10, 11]]))

torch.vsplit(my_tensor, (2, 3))
my_tensor.vsplit((2, 3))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64))

torch.vsplit(my_tensor, (3, 0))
my_tensor.vsplit((3, 0))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.vsplit(my_tensor, (3, 1))
my_tensor.vsplit((3, 1))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.vsplit(my_tensor, (3, 2))
my_tensor.vsplit((3, 2))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[8, 9, 10, 11]]))

torch.vsplit(my_tensor, (3, 3))
my_tensor.vsplit((3, 3))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([], size=(0, 4), dtype=torch.int64))

Ознакомьтесь с дополнительными операциями по работе с тензорами в PyTorch здесь

Источник:

#Python
Комментарии
Чтобы оставить комментарий, необходимо авторизоваться

Присоединяйся в тусовку

В этом месте могла бы быть ваша реклама

Разместить рекламу