|
6 | 6 | "metadata": { |
7 | 7 | "collapsed": true, |
8 | 8 | "ExecuteTime": { |
9 | | - "end_time": "2025-02-11T05:54:20.866680Z", |
10 | | - "start_time": "2025-02-11T05:54:14.109633Z" |
| 9 | + "end_time": "2025-02-13T04:32:49.529025Z", |
| 10 | + "start_time": "2025-02-13T04:32:36.307509Z" |
11 | 11 | } |
12 | 12 | }, |
13 | 13 | "source": [ |
|
781 | 781 | ], |
782 | 782 | "execution_count": 58 |
783 | 783 | }, |
| 784 | + { |
| 785 | + "metadata": { |
| 786 | + "ExecuteTime": { |
| 787 | + "end_time": "2025-02-12T08:47:02.233458Z", |
| 788 | + "start_time": "2025-02-12T08:47:02.228397Z" |
| 789 | + } |
| 790 | + }, |
| 791 | + "cell_type": "code", |
| 792 | + "source": [ |
| 793 | + "# 张量形状操作\n", |
| 794 | + "# 1. reshape\n", |
| 795 | + "t = torch.tensor([[10, 20, 30], [40, 50, 60]])\n", |
| 796 | + "# 1. 使用 shape 属性或者 size 方法都可以获得张量的形状\n", |
| 797 | + "print(t.shape, t.shape[0], t.shape[1])\n", |
| 798 | + "print(t.size(), t.size()[0], t.size()[1])\n", |
| 799 | + "# 2. 使用 reshape 函数修改张量形状\n", |
| 800 | + "t_ = t.reshape(2, -1) # -1 表示自动计算\n", |
| 801 | + "print(t_)" |
| 802 | + ], |
| 803 | + "id": "5568ad0c04e0f8a2", |
| 804 | + "outputs": [ |
| 805 | + { |
| 806 | + "name": "stdout", |
| 807 | + "output_type": "stream", |
| 808 | + "text": [ |
| 809 | + "torch.Size([2, 3]) 2 3\n", |
| 810 | + "torch.Size([2, 3]) 2 3\n", |
| 811 | + "tensor([[10, 20, 30],\n", |
| 812 | + " [40, 50, 60]])\n" |
| 813 | + ] |
| 814 | + } |
| 815 | + ], |
| 816 | + "execution_count": 16 |
| 817 | + }, |
| 818 | + { |
| 819 | + "metadata": { |
| 820 | + "ExecuteTime": { |
| 821 | + "end_time": "2025-02-12T08:56:41.735840Z", |
| 822 | + "start_time": "2025-02-12T08:56:41.708111Z" |
| 823 | + } |
| 824 | + }, |
| 825 | + "cell_type": "code", |
| 826 | + "source": [ |
| 827 | + "# 2. transpose, permute\n", |
| 828 | + "torch.manual_seed(0)\n", |
| 829 | + "t = torch.tensor(np.random.randint(0, 10, [3, 4, 5]))\n", |
| 830 | + "\n", |
| 831 | + "t1 = t.reshape(4, 3, 5) # 重新计算维度\n", |
| 832 | + "t2 = torch.transpose(t, 0, 1) # 直接交换 0、1维, 一次仅能交换2个维度\n", |
| 833 | + "t3 = torch.permute(t, (1, 0, 2)) # 交换 1、0维, 一次交换任意个维度\n", |
| 834 | + "\n", |
| 835 | + "t1.shape, t2.shape, t3.shape" |
| 836 | + ], |
| 837 | + "id": "59ec85414ba67a8b", |
| 838 | + "outputs": [ |
| 839 | + { |
| 840 | + "data": { |
| 841 | + "text/plain": [ |
| 842 | + "(torch.Size([4, 3, 5]), torch.Size([4, 3, 5]), torch.Size([4, 3, 5]))" |
| 843 | + ] |
| 844 | + }, |
| 845 | + "execution_count": 3, |
| 846 | + "metadata": {}, |
| 847 | + "output_type": "execute_result" |
| 848 | + } |
| 849 | + ], |
| 850 | + "execution_count": 3 |
| 851 | + }, |
| 852 | + { |
| 853 | + "metadata": { |
| 854 | + "ExecuteTime": { |
| 855 | + "end_time": "2025-02-12T09:31:47.417704Z", |
| 856 | + "start_time": "2025-02-12T09:31:47.400613Z" |
| 857 | + } |
| 858 | + }, |
| 859 | + "cell_type": "code", |
| 860 | + "source": [ |
| 861 | + "# 3. view, contiguous\n", |
| 862 | + "t = torch.tensor([[10, 20, 30], [40, 50, 60]])\n", |
| 863 | + "\n", |
| 864 | + "print(\"是否连续:\", t1.is_contiguous())\n", |
| 865 | + "t1 = t.view(3, 2) # 连续的, 可运行\n", |
| 866 | + "print(t1.shape)\n", |
| 867 | + "\n", |
| 868 | + "t2 = torch.transpose(t, 0, 1)\n", |
| 869 | + "print(\"是否连续:\", t2.is_contiguous())\n", |
| 870 | + "#t2 = t2.view(2, 3) # 形状改变, 非连续的, 不可运行\n", |
| 871 | + "t2 = t2.contiguous().view(2, 3) # 调用 contiguous() 方法, 使张量连续\n", |
| 872 | + "print(t2.shape)\n" |
| 873 | + ], |
| 874 | + "id": "1cc799ebfe58da2d", |
| 875 | + "outputs": [ |
| 876 | + { |
| 877 | + "name": "stdout", |
| 878 | + "output_type": "stream", |
| 879 | + "text": [ |
| 880 | + "是否连续: True\n", |
| 881 | + "torch.Size([3, 2])\n", |
| 882 | + "是否连续: False\n", |
| 883 | + "torch.Size([2, 3])\n" |
| 884 | + ] |
| 885 | + } |
| 886 | + ], |
| 887 | + "execution_count": 10 |
| 888 | + }, |
| 889 | + { |
| 890 | + "metadata": { |
| 891 | + "ExecuteTime": { |
| 892 | + "end_time": "2025-02-12T09:46:28.281353Z", |
| 893 | + "start_time": "2025-02-12T09:46:28.275795Z" |
| 894 | + } |
| 895 | + }, |
| 896 | + "cell_type": "code", |
| 897 | + "source": [ |
| 898 | + "# 4. squeeze, unsqueeze\n", |
| 899 | + "\n", |
| 900 | + "torch.manual_seed(0)\n", |
| 901 | + "t = torch.tensor(np.random.randint(0, 10, [1, 3, 1, 5]))\n", |
| 902 | + "print(t.shape)\n", |
| 903 | + "\n", |
| 904 | + "t1 = t.squeeze() # 删除所有1维\n", |
| 905 | + "print(t1.shape)\n", |
| 906 | + "\n", |
| 907 | + "t2 = t.squeeze(0) # 删除指定的1维, 若指定位置非1维, 则不删\n", |
| 908 | + "print(t2.shape)\n", |
| 909 | + "\n", |
| 910 | + "t3 = t1.unsqueeze(-1)# 在指定位置添加1维, -1 表示最后一个维度\n", |
| 911 | + "print(t3.shape)\n" |
| 912 | + ], |
| 913 | + "id": "681ada560059274e", |
| 914 | + "outputs": [ |
| 915 | + { |
| 916 | + "name": "stdout", |
| 917 | + "output_type": "stream", |
| 918 | + "text": [ |
| 919 | + "torch.Size([1, 3, 1, 5])\n", |
| 920 | + "torch.Size([3, 5])\n", |
| 921 | + "torch.Size([3, 1, 5])\n", |
| 922 | + "torch.Size([3, 5, 1])\n" |
| 923 | + ] |
| 924 | + } |
| 925 | + ], |
| 926 | + "execution_count": 17 |
| 927 | + }, |
| 928 | + { |
| 929 | + "metadata": { |
| 930 | + "ExecuteTime": { |
| 931 | + "end_time": "2025-02-12T09:59:35.412382Z", |
| 932 | + "start_time": "2025-02-12T09:59:35.348267Z" |
| 933 | + } |
| 934 | + }, |
| 935 | + "cell_type": "code", |
| 936 | + "source": [ |
| 937 | + "# 张量运算函数\n", |
| 938 | + "torch.manual_seed(0)\n", |
| 939 | + "t = torch.tensor(np.random.randint(0, 10, [2, 3]), dtype=torch.float64)\n", |
| 940 | + "# 1. 均值\n", |
| 941 | + "t1 = t.mean()\n", |
| 942 | + "t2 = t.mean(dim=0) # 按列求均值\n", |
| 943 | + "t3 = t.mean(dim=1) # 按行求均值\n", |
| 944 | + "# 2. 求和\n", |
| 945 | + "t4 = t.sum() # 全部求和\n", |
| 946 | + "#t = t.sum(dim=0)\n", |
| 947 | + "# 3. 平方\n", |
| 948 | + "t5 = t.pow(2) # 每个元素平方\n", |
| 949 | + "# 4. 平方根\n", |
| 950 | + "t6 = t.sqrt() # 每个元素平方根\n", |
| 951 | + "# 5. e^\n", |
| 952 | + "t7 = t.exp() # 每个元素exp\n", |
| 953 | + "# 6. log\n", |
| 954 | + "t8 = t.log() # 每个元素log, 默认以e为底\n", |
| 955 | + "t9 = t.log10() # 每个元素log, 以10为底\n", |
| 956 | + "\n", |
| 957 | + "t, t1, t2, t3, t4, t5, t6, t7, t8, t9\n" |
| 958 | + ], |
| 959 | + "id": "6d5f486598dc0804", |
| 960 | + "outputs": [ |
| 961 | + { |
| 962 | + "data": { |
| 963 | + "text/plain": [ |
| 964 | + "(tensor([[9., 0., 0.],\n", |
| 965 | + " [4., 9., 0.]], dtype=torch.float64),\n", |
| 966 | + " tensor(3.6667, dtype=torch.float64),\n", |
| 967 | + " tensor([6.5000, 4.5000, 0.0000], dtype=torch.float64),\n", |
| 968 | + " tensor([3.0000, 4.3333], dtype=torch.float64),\n", |
| 969 | + " tensor(22., dtype=torch.float64),\n", |
| 970 | + " tensor([[81., 0., 0.],\n", |
| 971 | + " [16., 81., 0.]], dtype=torch.float64),\n", |
| 972 | + " tensor([[3., 0., 0.],\n", |
| 973 | + " [2., 3., 0.]], dtype=torch.float64),\n", |
| 974 | + " tensor([[8.1031e+03, 1.0000e+00, 1.0000e+00],\n", |
| 975 | + " [5.4598e+01, 8.1031e+03, 1.0000e+00]], dtype=torch.float64),\n", |
| 976 | + " tensor([[2.1972, -inf, -inf],\n", |
| 977 | + " [1.3863, 2.1972, -inf]], dtype=torch.float64),\n", |
| 978 | + " tensor([[0.9542, -inf, -inf],\n", |
| 979 | + " [0.6021, 0.9542, -inf]], dtype=torch.float64))" |
| 980 | + ] |
| 981 | + }, |
| 982 | + "execution_count": 23, |
| 983 | + "metadata": {}, |
| 984 | + "output_type": "execute_result" |
| 985 | + } |
| 986 | + ], |
| 987 | + "execution_count": 23 |
| 988 | + }, |
| 989 | + { |
| 990 | + "metadata": { |
| 991 | + "ExecuteTime": { |
| 992 | + "end_time": "2025-02-13T06:12:09.605484Z", |
| 993 | + "start_time": "2025-02-13T06:12:09.505066Z" |
| 994 | + } |
| 995 | + }, |
| 996 | + "cell_type": "code", |
| 997 | + "source": "", |
| 998 | + "id": "4e0bb83e57097993", |
| 999 | + "outputs": [ |
| 1000 | + { |
| 1001 | + "name": "stdout", |
| 1002 | + "output_type": "stream", |
| 1003 | + "text": [ |
| 1004 | + "1\n" |
| 1005 | + ] |
| 1006 | + } |
| 1007 | + ], |
| 1008 | + "execution_count": 2 |
| 1009 | + }, |
784 | 1010 | { |
785 | 1011 | "metadata": {}, |
786 | 1012 | "cell_type": "code", |
787 | 1013 | "outputs": [], |
788 | 1014 | "execution_count": null, |
789 | 1015 | "source": "", |
790 | | - "id": "5568ad0c04e0f8a2" |
| 1016 | + "id": "ac8e388022657ab4" |
791 | 1017 | } |
792 | 1018 | ], |
793 | 1019 | "metadata": { |
|
0 commit comments