Este blog ejemplificará el método para utilizar el método “torch.argmax()” en PyTorch.
¿Cómo utilizar el método “torch.argmax()” en PyTorch?
El método “torch.argmax()” toma cualquier tensor 1D o 2D como entrada y devuelve un tensor que contiene los índices/índices de los valores máximos a lo largo de la dimensión dada.
La sintaxis del método “torch.argmax()” se proporciona a continuación:
antorcha. argmax ( < tensor_entrada > )
Para utilizar este método en PyTorch, consulte los siguientes ejemplos para una mejor comprensión:
Ejemplo 1: utilizar el método “torch.argmax()” con tensor 1D
En el primer ejemplo, crearemos un tensor 1D y usaremos el método 'torch.argmax()' con él. Sigamos el siguiente procedimiento paso a paso:
Paso 1: importar la biblioteca PyTorch
Primero, importe el ' antorcha ”biblioteca para usar el método “torch.argmax()”:
importar antorchaPaso 2: crear tensor 1D
Luego, crea un tensor 1D e imprime sus elementos. Aquí, estamos creando lo siguiente ' Decenas1 ' tensor de una lista usando el ' antorcha.tensor() ' función:
Decenas1 = antorcha. tensor ( [ 5 , 0 , - 8 , 1 , 9 , 7 ] )
imprimir ( Decenas1 )
Esto ha creado un tensor 1D como se ve a continuación:
Paso 3: encontrar índices de valor máximo
Ahora, utilice el ' antorcha.argmax() 'Función para encontrar el índice/índices del valor máximo en el' Decenas1 'tensor:
T1_ind = antorcha. argmax ( Decenas1 )Paso 4: Imprimir índice de valor máximo
Por último, muestre el índice del valor máximo en el tensor de entrada:
imprimir ( 'Índices:' , T1_ind )El siguiente resultado muestra el índice del valor máximo en el ' Decenas1 'tensor, es decir, 4. Significa que el valor más alto del tensor está en el cuarto índice que es' 9 ”:
Ejemplo 2: utilizar el método “torch.argmax()” con tensor 2D
En el segundo ejemplo, crearemos un tensor 2D y usaremos el método 'torch.argmax()' con él. Sigamos los pasos proporcionados:
Paso 1: importar la biblioteca PyTorch
Primero, importe el ' antorcha ”biblioteca para usar el método “torch.argmax()”:
importar antorchaPaso 2: crear tensor 2D
Luego, utilice el botón ' antorcha.tensor() ”Función para crear un tensor 2D e imprimir sus elementos. Aquí, estamos creando lo siguiente “ decenas2 “Tensor 2D:
decenas2 = antorcha. tensor ( [ [ 4 , 1 , - 7 ] , [ 15 , 6 , 0 ] , [ - 7 , 9 , 2 ] ] )imprimir ( decenas2 )
Esto ha creado un tensor 2D como se ve a continuación:
Paso 3: encontrar índices de valor máximo
Ahora, encuentre el índice del valor máximo en el “ decenas2 ' tensor utilizando el ' antorcha.argmax() ' función:
T2_ind = antorcha. argmax ( decenas2 )Paso 4: Imprimir índice de valor máximo
Finalmente, muestre el índice del valor máximo en el tensor de entrada:
imprimir ( 'Índices:' , T2_ind )Según el siguiente resultado, el índice del valor máximo en el ' decenas2 El tensor es “3”. Significa que el valor más alto del tensor está en el tercer índice que es ' 15 ”:
Paso 5: encontrar índices de valor máximo a lo largo de las columnas
Además, los usuarios también pueden encontrar los índices/índices de los valores máximos a lo largo de cada columna de un tensor. Por ejemplo, podemos utilizar el ' tenue=0 ”argumento con la función “torch.argmax()”. Encuentra los índices de los valores máximos a lo largo de las columnas del cuadro ' decenas2 ”tensor y luego imprime esos índices:
índice_columna = antorcha. argmax ( decenas2 , oscuro = 0 )imprimir ( 'Índices en columnas:' , índice_columna )
El siguiente resultado muestra los índices de los valores máximos a lo largo de cada columna del tensor:
Paso 6: encontrar índices de valor máximo a lo largo de las filas
De manera similar, los usuarios también pueden encontrar los índices/índices de los valores máximos a lo largo de cada fila de un tensor. Por ejemplo, utilice el ' tenue=1 ”argumento con la función “torch.argmax()” para encontrar los índices de los valores máximos a lo largo de las filas en el tensor “Tens2” y luego imprimir esos índices:
indice_row = antorcha. argmax ( decenas2 , oscuro = 1 )imprimir ( 'Índices en filas:' , indice_row )
Los índices del valor máximo a lo largo de cada fila de un tensor 'Decenas2' se pueden ver a continuación:
Hemos explicado de manera eficiente el método para usar el método “torch.argmax()” en PyTorch.
Nota : Puede acceder a nuestro Google Colab Notebook en este enlace .
Conclusión
Para utilizar el método 'torch.argmax()' en PyTorch, primero importe el archivo ' antorcha ' biblioteca. Luego, cree el tensor 1D o 2D deseado y vea sus elementos. A continuación, utilice el botón ' antorcha.argmax() 'Método para encontrar/calcular los índices/índices de los valores máximos en el tensor. Además, los usuarios también pueden encontrar los índices del valor máximo a lo largo de cada fila o columna del tensor usando el botón ' oscuro ' argumento. Finalmente, muestre el índice del valor máximo en el tensor de entrada. Este blog ha ejemplificado el método para utilizar el método “torch.argmax()” en PyTorch.