TensorFlow-数据变换-tf.unstack(value, num=None, axis=0, name="unstack")

tf.unstack在循环神经网络的搭建中出现过,这里记录下方便自己记忆

功能:将输入value按照指定axis(维度)拆分(从0开始),输出含有num个元素的列表,num必须和指定维度内元素的个数相等,当然可以忽略不写这个参数,比如:tf.unstack(X, axis=0)

举个例子,假设value.shape为(2,3,4),
如果axis=0,那么num就必须填2,变换后list有2个元素,元素的shape为(3,4)
如果axis=1,那么num就必须填3,变换后list有3个元素,元素的shape为(2,4)
如果axis=2,那么num就必须填4,变换后list有4个元素,元素的shape为(2,3)

import tensorflow as tf
import numpy as np

X = tf.constant(np.array(range(24)).reshape(2, 3, 4))

X0 = tf.unstack(X, 2, 0)
X1 = tf.unstack(X, 3, 1)
X2 = tf.unstack(X, 4, 2)

with tf.Session() as sess:
    ts = [X, X0, X1, X2]
    xs = sess.run([X, X0, X1, X2])
    for t, x in zip(ts, xs):
        print(t, '\n', x, '\n')

将输出手动美化后的结果如下,依次是X,X0,X1,X2

Tensor("Const:0", shape=(2, 3, 4), dtype=int64) 
[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]
  
 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]] 

[, 
 ] 
[array([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]]), 
 array([[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]])] 

[, 
 , 
 ] 
[array([[ 0,  1,  2,  3],
        [12, 13, 14, 15]]), 
 array([[ 4,  5,  6,  7],
        [16, 17, 18, 19]]), 
 array([[ 8,  9, 10, 11],
        [20, 21, 22, 23]])] 

[, 
 , 
 , 
 ] 
[array([[ 0,  4,  8],
        [12, 16, 20]]), 
 array([[ 1,  5,  9],
        [13, 17, 21]]), 
 array([[ 2,  6, 10],
        [14, 18, 22]]), 
 array([[ 3,  7, 11],
        [15, 19, 23]])] 

那么现在分析下循环神经网络中对mnist图片的处理,图片形状是28*28的像素矩阵

n_steps = 28  
n_input = 28 

X = tf.placeholder("float", [None, n_steps, n_input])
X1 = tf.unstack(X, n_steps, 1)

为什么这么写呢,因为循环神经网络适合解决连续序列的问题,所以这里是将一张图片转化为一个序列格式的数据,28行28列的像素矩阵,tf.unstack(X, n_steps, 1)将其转化为一个列表长度为28,列表元素为每行像素点,即,一行为一个序列,每行像素点为序列特征

n_steps = 28  
n_input = 28 

X = tf.placeholder("float", [None, n_steps, n_input])
X1 = tf.unstack(X, n_steps, 2)

那么,这样写呢,让axis=2,其实这只是换了个角度,将每列像素点看做一个序列特征,第一列为序列1,第二列为序列2,第三列为序列3…
TensorFlow-数据变换-tf.unstack(value, num=None, axis=0, name=
为了直观地看出 tf.unstack 的变换,不妨写一段代码

import tensorflow as tf
import numpy as np

X = tf.constant(np.array(range(1 * 28 * 28)).reshape(1, 28, 28))

X1 = tf.unstack(X, 28, 1)
X2 = tf.unstack(X, 28, 2)

with tf.Session() as sess:
    xs = sess.run([X, X1, X2])
    [print(x) for x in xs]

将输出手动美化后的结果如下,分别是X,以及axis=1与axis=2的变换结果

[[[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16 17  18  19  20  21  22  23  24  25  26  27]
  [ 28  29  30  31  32  33  34  35  36  37  38  39  40  41  42  43  44 45  46  47  48  49  50  51  52  53  54  55]
  [ 56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72 73  74  75  76  77  78  79  80  81  82  83]
  [ 84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107 108 109 110 111]
  [112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139]
  [140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167]
  [168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195]
  [196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223]
  [224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251]
  [252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279]
  [280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307]
  [308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335]
  [336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363]
  [364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391]
  [392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419]
  [420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447]
  [448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475]
  [476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503]
  [504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531]
  [532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559]
  [560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587]
  [588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615]
  [616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643]
  [644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671]
  [672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699]
  [700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727]
  [728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755]
  [756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783]]]  

[array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]]), 
 array([[28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55]]), 
 array([[56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83]]), 
 array([[ 84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96, 97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111]]), 
 array([[112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139]]), 
 array([[140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167]]), 
 array([[168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195]]), 
 array([[196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223]]),
 array([[224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236,237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251]]),
 array([[252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279]]),
 array([[280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307]]),
 array([[308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335]]),
 array([[336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363]]),
 array([[364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391]]), 
 array([[392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419]]),
 array([[420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447]]),
 array([[448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475]]),
 array([[476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503]]),
 array([[504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531]]),
 array([[532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559]]),
 array([[560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587]]),
 array([[588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615]]),
 array([[616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643]]),
 array([[644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671]]),
 array([[672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699]]),
 array([[700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727]]),
 array([[728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755]]),
 array([[756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783]])]

[array([[  0,  28,  56,  84, 112, 140, 168, 196, 224, 252, 280, 308, 336, 364, 392, 420, 448, 476, 504, 532, 560, 588, 616, 644, 672, 700, 728, 756]]),
 array([[  1,  29,  57,  85, 113, 141, 169, 197, 225, 253, 281, 309, 337, 365, 393, 421, 449, 477, 505, 533, 561, 589, 617, 645, 673, 701, 729, 757]]),
 array([[  2,  30,  58,  86, 114, 142, 170, 198, 226, 254, 282, 310, 338, 366, 394, 422, 450, 478, 506, 534, 562, 590, 618, 646, 674, 702, 730, 758]]),
 array([[  3,  31,  59,  87, 115, 143, 171, 199, 227, 255, 283, 311, 339, 367, 395, 423, 451, 479, 507, 535, 563, 591, 619, 647, 675, 703, 731, 759]]),
 array([[  4,  32,  60,  88, 116, 144, 172, 200, 228, 256, 284, 312, 340, 368, 396, 424, 452, 480, 508, 536, 564, 592, 620, 648, 676, 704, 732, 760]]),
 array([[  5,  33,  61,  89, 117, 145, 173, 201, 229, 257, 285, 313, 341, 369, 397, 425, 453, 481, 509, 537, 565, 593, 621, 649, 677, 705, 733, 761]]),
 array([[  6,  34,  62,  90, 118, 146, 174, 202, 230, 258, 286, 314, 342, 370, 398, 426, 454, 482, 510, 538, 566, 594, 622, 650, 678, 706, 734, 762]]),
 array([[  7,  35,  63,  91, 119, 147, 175, 203, 231, 259, 287, 315, 343, 371, 399, 427, 455, 483, 511, 539, 567, 595, 623, 651, 679, 707, 735, 763]]),
 array([[  8,  36,  64,  92, 120, 148, 176, 204, 232, 260, 288, 316, 344, 372, 400, 428, 456, 484, 512, 540, 568, 596, 624, 652, 680, 708, 736, 764]]),
 array([[  9,  37,  65,  93, 121, 149, 177, 205, 233, 261, 289, 317, 345, 373, 401, 429, 457, 485, 513, 541, 569, 597, 625, 653, 681, 709, 737, 765]]),
 array([[ 10,  38,  66,  94, 122, 150, 178, 206, 234, 262, 290, 318, 346, 374, 402, 430, 458, 486, 514, 542, 570, 598, 626, 654, 682, 710, 738, 766]]),
 array([[ 11,  39,  67,  95, 123, 151, 179, 207, 235, 263, 291, 319, 347, 375, 403, 431, 459, 487, 515, 543, 571, 599, 627, 655, 683, 711, 739, 767]]),
 array([[ 12,  40,  68,  96, 124, 152, 180, 208, 236, 264, 292, 320, 348, 376, 404, 432, 460, 488, 516, 544, 572, 600, 628, 656, 684, 712, 740, 768]]),
 array([[ 13,  41,  69,  97, 125, 153, 181, 209, 237, 265, 293, 321, 349, 377, 405, 433, 461, 489, 517, 545, 573, 601, 629, 657, 685, 713, 741, 769]]),
 array([[ 14,  42,  70,  98, 126, 154, 182, 210, 238, 266, 294, 322, 350, 378, 406, 434, 462, 490, 518, 546, 574, 602, 630, 658, 686, 714, 742, 770]]),
 array([[ 15,  43,  71,  99, 127, 155, 183, 211, 239, 267, 295, 323, 351, 379, 407, 435, 463, 491, 519, 547, 575, 603, 631, 659, 687, 715, 743, 771]]),
 array([[ 16,  44,  72, 100, 128, 156, 184, 212, 240, 268, 296, 324, 352, 380, 408, 436, 464, 492, 520, 548, 576, 604, 632, 660, 688, 716, 744, 772]]),
 array([[ 17,  45,  73, 101, 129, 157, 185, 213, 241, 269, 297, 325, 353, 381, 409, 437, 465, 493, 521, 549, 577, 605, 633, 661, 689, 717, 745, 773]]),
 array([[ 18,  46,  74, 102, 130, 158, 186, 214, 242, 270, 298, 326, 354, 382, 410, 438, 466, 494, 522, 550, 578, 606, 634, 662, 690, 718, 746, 774]]),
 array([[ 19,  47,  75, 103, 131, 159, 187, 215, 243, 271, 299, 327, 355, 383, 411, 439, 467, 495, 523, 551, 579, 607, 635, 663, 691, 719, 747, 775]]),
 array([[ 20,  48,  76, 104, 132, 160, 188, 216, 244, 272, 300, 328, 356, 384, 412, 440, 468, 496, 524, 552, 580, 608, 636, 664, 692, 720, 748, 776]]),
 array([[ 21,  49,  77, 105, 133, 161, 189, 217, 245, 273, 301, 329, 357, 385, 413, 441, 469, 497, 525, 553, 581, 609, 637, 665, 693, 721, 749, 777]]),
 array([[ 22,  50,  78, 106, 134, 162, 190, 218, 246, 274, 302, 330, 358, 386, 414, 442, 470, 498, 526, 554, 582, 610, 638, 666, 694, 722, 750, 778]]),
 array([[ 23,  51,  79, 107, 135, 163, 191, 219, 247, 275, 303, 331, 359, 387, 415, 443, 471, 499, 527, 555, 583, 611, 639, 667, 695, 723, 751, 779]]),
 array([[ 24,  52,  80, 108, 136, 164, 192, 220, 248, 276, 304, 332, 360, 388, 416, 444, 472, 500, 528, 556, 584, 612, 640, 668, 696, 724, 752, 780]]),
 array([[ 25,  53,  81, 109, 137, 165, 193, 221, 249, 277, 305, 333, 361, 389, 417, 445, 473, 501, 529, 557, 585, 613, 641, 669, 697, 725, 753, 781]]),
 array([[ 26,  54,  82, 110, 138, 166, 194, 222, 250, 278, 306, 334, 362, 390, 418, 446, 474, 502, 530, 558, 586, 614, 642, 670, 698, 726, 754, 782]]),
 array([[ 27,  55,  83, 111, 139, 167, 195, 223, 251, 279, 307, 335, 363, 391, 419, 447, 475, 503, 531, 559, 587, 615, 643, 671, 699, 727, 755, 783]])]

你可能感兴趣的:(TensorFlow)