summaryrefslogtreecommitdiff
path: root/lib/odp-execute-avx512.c
blob: c28461ec1a0d056b4925bd908505b054c24ee048 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
/*
 * Copyright (c) 2022 Intel.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at:
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifdef __x86_64__
/* Sparse cannot handle the AVX512 instructions. */
#if !defined(__CHECKER__)

#include <config.h>
#include <errno.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <netinet/ip6.h>

#include "csum.h"
#include "dp-packet.h"
#include "immintrin.h"
#include "odp-execute.h"
#include "odp-execute-private.h"
#include "odp-netlink.h"
#include "openvswitch/vlog.h"
#include "packets.h"

VLOG_DEFINE_THIS_MODULE(odp_execute_avx512);

/* avx512_dp_packet_resize_l2() treats l2_pad_size, l2_5_ofs, l3_ofs and
 * l4_ofs of struct dp_packet as four consecutive 16-bit lanes of a single
 * 128-bit vector.  The next three asserts pin that layout down: each of
 * these members must start exactly where the previous one ends. */
BUILD_ASSERT_DECL(offsetof(struct dp_packet, l2_pad_size) +
                  MEMBER_SIZEOF(struct dp_packet, l2_pad_size) ==
                  offsetof(struct dp_packet, l2_5_ofs));

BUILD_ASSERT_DECL(offsetof(struct dp_packet, l2_5_ofs) +
                  MEMBER_SIZEOF(struct dp_packet, l2_5_ofs) ==
                  offsetof(struct dp_packet, l3_ofs));

BUILD_ASSERT_DECL(offsetof(struct dp_packet, l3_ofs) +
                  MEMBER_SIZEOF(struct dp_packet, l3_ofs) ==
                  offsetof(struct dp_packet, l4_ofs));

/* A full 128-bit load/store that starts at l2_pad_size must stay inside
 * struct dp_packet. */
BUILD_ASSERT_DECL(sizeof(struct dp_packet) -
                  offsetof(struct dp_packet, l2_pad_size) >= sizeof(__m128i));

/* The shuffle tables used by the set-masked handlers below hard-code the
 * member order of the netlink key structures.  That order comes from the
 * Linux uapi and is not expected to change, but verify it at build time. */
BUILD_ASSERT_DECL(offsetof(struct ovs_key_ethernet, eth_src) +
                  MEMBER_SIZEOF(struct ovs_key_ethernet, eth_src) ==
                  offsetof(struct ovs_key_ethernet, eth_dst));

BUILD_ASSERT_DECL(offsetof(struct ovs_key_ipv4, ipv4_src) +
                  MEMBER_SIZEOF(struct ovs_key_ipv4, ipv4_src) ==
                  offsetof(struct ovs_key_ipv4, ipv4_dst));

BUILD_ASSERT_DECL(offsetof(struct ovs_key_ipv4, ipv4_dst) +
                  MEMBER_SIZEOF(struct ovs_key_ipv4, ipv4_dst) ==
                  offsetof(struct ovs_key_ipv4, ipv4_proto));

BUILD_ASSERT_DECL(offsetof(struct ovs_key_ipv4, ipv4_proto) +
                  MEMBER_SIZEOF(struct ovs_key_ipv4, ipv4_proto) ==
                  offsetof(struct ovs_key_ipv4, ipv4_tos));

BUILD_ASSERT_DECL(offsetof(struct ovs_key_ipv4, ipv4_tos) +
                  MEMBER_SIZEOF(struct ovs_key_ipv4, ipv4_tos) ==
                  offsetof(struct ovs_key_ipv4, ipv4_ttl));

BUILD_ASSERT_DECL(offsetof(struct ovs_key_ipv6, ipv6_src) +
                  MEMBER_SIZEOF(struct ovs_key_ipv6, ipv6_src) ==
                  offsetof(struct ovs_key_ipv6, ipv6_dst));

BUILD_ASSERT_DECL(offsetof(struct ovs_key_ipv6, ipv6_dst) +
                  MEMBER_SIZEOF(struct ovs_key_ipv6, ipv6_dst) ==
                  offsetof(struct ovs_key_ipv6, ipv6_label));

BUILD_ASSERT_DECL(offsetof(struct ovs_key_ipv6, ipv6_label) +
                  MEMBER_SIZEOF(struct ovs_key_ipv6, ipv6_label) ==
                  offsetof(struct ovs_key_ipv6, ipv6_proto));

BUILD_ASSERT_DECL(offsetof(struct ovs_key_ipv6, ipv6_proto) +
                  MEMBER_SIZEOF(struct ovs_key_ipv6, ipv6_proto) ==
                  offsetof(struct ovs_key_ipv6, ipv6_tclass));

BUILD_ASSERT_DECL(offsetof(struct ovs_key_ipv6, ipv6_tclass) +
                  MEMBER_SIZEOF(struct ovs_key_ipv6, ipv6_tclass) ==
                  offsetof(struct ovs_key_ipv6, ipv6_hlimit));

/* SET_MASKED dispatch table, indexed by the OVS_KEY_ATTR_* type of the
 * nested key attribute.  A NULL entry makes action_avx512_set_masked() fall
 * back to the scalar implementation. */
odp_execute_action_cb impl_set_masked_funcs[__OVS_KEY_ATTR_MAX];

static inline void ALWAYS_INLINE
avx512_dp_packet_resize_l2(struct dp_packet *b, int resize_by_bytes)
{
    /* Grow or shrink the packet at the front, exactly like the scalar
     * implementation does. */
    if (resize_by_bytes < 0) {
        dp_packet_pull(b, -resize_by_bytes);
    } else {
        dp_packet_push_uninit(b, resize_by_bytes);
    }

    /* Shift l2_5_ofs, l3_ofs and l4_ofs by the same amount, but leave any of
     * them that holds UINT16_MAX untouched, as the scalar
     * dp_packet_adjust_layer_offset() does.  The four 16-bit members that
     * start at l2_pad_size are handled as lanes 0..3 of one 128-bit vector;
     * lane 0 (l2_pad_size itself) is excluded by this lane mask. */
    const __mmask8 k_ofs_lanes = 0b1110;

    /* UINT16_MAX in every 16-bit lane. */
    const __m128i v_unset = _mm_set1_epi16(-1);

    /* Broadcast the signed adjustment to every 16-bit lane.  Adding the
     * two's-complement value modulo 2^16 is the same as subtracting its
     * magnitude, so one masked add handles both growing and shrinking. */
    const __m128i v_delta = _mm_set1_epi16((short) resize_by_bytes);

    /* Read l2_pad_size and the three layer offsets (plus whatever follows,
     * which is written back unchanged below). */
    void *ofs_base = &b->l2_pad_size;
    const __m128i v_old = _mm_loadu_si128(ofs_base);

    /* Bit n of k_set is set when lane n is one of the three offset lanes
     * AND its value differs from UINT16_MAX. */
    const __mmask8 k_set = _mm_mask_cmpneq_epu16_mask(k_ofs_lanes, v_old,
                                                      v_unset);

    /* Lanes outside k_set are copied from v_old unchanged. */
    const __m128i v_new = _mm_mask_add_epi16(v_old, k_set, v_old, v_delta);

    _mm_storeu_si128(ofs_base, v_new);
}

/* Removes the outermost 802.1Q header from every packet in 'batch' that has
 * one, performing the same operation as the scalar eth_pop_vlan(). */
static void
action_avx512_pop_vlan(struct dp_packet_batch *batch,
                       const struct nlattr *a OVS_UNUSED)
{
    const __m128i v_zero_fill = _mm_setzero_si128();
    struct dp_packet *pkt;

    DP_PACKET_BATCH_FOR_EACH (i, pkt, batch) {
        struct vlan_eth_header *vlan_hdr = dp_packet_eth(pkt);

        /* Skip packets that lack a complete, VLAN tagged Ethernet header. */
        if (!vlan_hdr || dp_packet_size(pkt) < sizeof *vlan_hdr ||
            !eth_type_vlan(vlan_hdr->veth_type)) {
            continue;
        }

        /* Bytes 0..15: destination MAC, source MAC, TPID and TCI. */
        __m128i v_hdr = _mm_loadu_si128((void *) vlan_hdr);

        /* Concatenate v_hdr above 16 zero bytes and shift right, keeping the
         * low 128 bits: the result is VLAN_HEADER_LEN zero bytes followed by
         * the 12 MAC address bytes, which thereby overwrite TPID/TCI. */
        __m128i v_moved = _mm_alignr_epi8(v_hdr, v_zero_fill,
                                          sizeof(__m128i) - VLAN_HEADER_LEN);

        _mm_storeu_si128((void *) vlan_hdr, v_moved);

        /* Drop the now unused leading bytes and fix up the layer offsets. */
        avx512_dp_packet_resize_l2(pkt, -VLAN_HEADER_LEN);
    }
}

/* This function performs the same operation on each packet in the batch as
 * the scalar eth_push_vlan() function. */
static void
action_avx512_push_vlan(struct dp_packet_batch *batch, const struct nlattr *a)
{
    struct dp_packet *packet;
    const struct ovs_action_push_vlan *vlan = nl_attr_get(a);
    ovs_be16 tpid, tci;

    /* This shuffle mask is used below, and each position tells where to
     * move the bytes to. So here, the fourth byte in v_ether is moved to
     * byte location 0 in v_shift. The fifth is moved to 1, etc., etc.
     * The 0xFF is special it tells to fill that position with 0. */
    static const uint8_t vlan_push_shuffle_mask[16] = {
        4, 5, 6, 7, 8, 9, 10, 11,
        12, 13, 14, 15, 0xFF, 0xFF, 0xFF, 0xFF
    };

    /* Load the shuffle mask in v_index. */
    __m128i v_index = _mm_loadu_si128((void *) vlan_push_shuffle_mask);

    DP_PACKET_BATCH_FOR_EACH (i, packet, batch) {
        tpid = vlan->vlan_tpid;
        tci = vlan->vlan_tci;

        /* As we are about to insert the VLAN_HEADER we now need to adjust all
         * the offsets. */
        avx512_dp_packet_resize_l2(packet, VLAN_HEADER_LEN);

        char *pkt_data = (char *) dp_packet_data(packet);

        /* Build up the VLAN TCI/TPID in a single uint32_t. */
        const uint32_t tci_proc = tci & htons(~VLAN_CFI);
        const uint32_t tpid_tci = (tci_proc << 16) | tpid;

        /* Load the first 128-bits of the packet into the v_ether register.
         * Note that this includes the 4 unused bytes (VLAN_HEADER_LEN). */
        __m128i v_ether = _mm_loadu_si128((void *) pkt_data);

        /* Move(shuffle) the veth_dst and veth_src data to create room for
         * the vlan header. */
        __m128i v_shift = _mm_shuffle_epi8(v_ether, v_index);

        /* Copy(insert) the 32-bit VLAN header, tpid_tci, at the 3rd 32-bit
         * word offset, i.e. ofssetof(vlan_eth_header, veth_type) */
        __m128i v_vlan_hdr = _mm_insert_epi32(v_shift, tpid_tci, 3);

        /* Write back the modified ethernet header. */
        _mm_storeu_si128((void *) pkt_data, v_vlan_hdr);
    }
}

/* Applies the masked Ethernet source/destination rewrite carried in the
 * OVS_ACTION_ATTR_SET_MASKED attribute 'a' to every packet in 'batch',
 * performing the same operation as the scalar odp_eth_set_addrs(). */
static void
action_avx512_eth_set_addrs(struct dp_packet_batch *batch,
                            const struct nlattr *a)
{
    /* struct ovs_key_ethernet stores eth_src before eth_dst, whereas the
     * Ethernet header starts with the destination.  This table swaps the
     * two 6-byte addresses and zeroes the last four bytes (0xFF -> 0). */
    static const uint8_t eth_shuffle[16] = {
        6, 7, 8, 9, 10, 11, 0, 1,
        2, 3, 4, 5, 0xFF, 0xFF, 0xFF, 0xFF
    };

    const struct nlattr *nested = nl_attr_get(a);
    const struct ovs_key_ethernet *key = nl_attr_get(nested);
    const struct ovs_key_ethernet *mask =
        odp_get_key_mask(nested, struct ovs_key_ethernet);
    struct dp_packet *packet;

    /* Key and mask only hold the two addresses (96 bits), so load just
     * three 32-bit words of each. */
    __m128i v_key = _mm_maskz_loadu_epi32(0x7, (void *) key);
    __m128i v_msk = _mm_maskz_loadu_epi32(0x7, (void *) mask);
    const __m128i v_order = _mm_loadu_si128((void *) eth_shuffle);

    /* Reorder key and mask into on-the-wire (dst, src) order. */
    v_key = _mm_shuffle_epi8(v_key, v_order);
    v_msk = _mm_shuffle_epi8(v_msk, v_order);

    DP_PACKET_BATCH_FOR_EACH (i, packet, batch) {
        struct eth_header *eh = dp_packet_eth(packet);

        if (eh) {
            /* new = key | (old & ~mask).  The key is OR'ed in as-is.
             * NOTE(review): this relies on the key having no bits set
             * outside the mask. */
            __m128i v_pkt = _mm_loadu_si128((void *) eh);
            __m128i v_keep = _mm_andnot_si128(v_msk, v_pkt);
            __m128i v_out = _mm_or_si128(v_key, v_keep);

            /* Bytes 12..15 have a zero key and mask, so the full 128-bit
             * store writes them back unchanged. */
            _mm_storeu_si128((void *) eh, v_out);
        }
    }
}

/* Returns the ones' complement sum of the first ten 16-bit words (20 bytes)
 * of 'old_header' plus the first ten 16-bit words of 'new_header', folded to
 * 16 bits and not inverted.  Callers pass an already inverted old header, so
 * the result is the checksum delta between the two headers. */
static inline uint16_t ALWAYS_INLINE
avx512_get_delta(__m256i old_header, __m256i new_header)
{
    __m256i v_zeros = _mm256_setzero_si256();

    /* These two shuffle masks, v_swap16a and v_swap16b, are to shuffle the
     * old and new header to add padding after each 16-bit value for the
     * following carry over addition.  v_swap16a picks words 0..3 from the
     * lower lane and words 8..9 from the upper lane; v_swap16b picks words
     * 4..7. */
    __m256i v_swap16a = _mm256_setr_epi16(0x0100, 0xFFFF, 0x0302, 0xFFFF,
                                          0x0504, 0xFFFF, 0x0706, 0xFFFF,
                                          0x0100, 0xFFFF, 0x0302, 0xFFFF,
                                          0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF);
    __m256i v_swap16b = _mm256_setr_epi16(0x0908, 0xFFFF, 0x0B0A, 0xFFFF,
                                          0x0D0C, 0xFFFF, 0x0F0E, 0xFFFF,
                                          0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF,
                                          0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF);
    __m256i v_shuf_old1 = _mm256_shuffle_epi8(old_header, v_swap16a);
    __m256i v_shuf_old2 = _mm256_shuffle_epi8(old_header, v_swap16b);
    __m256i v_shuf_new1 = _mm256_shuffle_epi8(new_header, v_swap16a);
    __m256i v_shuf_new2 = _mm256_shuffle_epi8(new_header, v_swap16b);

    /* Add each part of the old and new headers together. */
    __m256i v_delta1 = _mm256_add_epi32(v_shuf_old1, v_shuf_new1);
    __m256i v_delta2 = _mm256_add_epi32(v_shuf_old2, v_shuf_new2);

    /* Add old and new header. */
    __m256i v_delta = _mm256_add_epi32(v_delta1, v_delta2);

    /* Perform horizontal add to go from 8x32-bits to 2x32-bits. */
    v_delta = _mm256_hadd_epi32(v_delta, v_zeros);
    v_delta = _mm256_hadd_epi32(v_delta, v_zeros);

    /* Shuffle 32-bit value from 3rd lane into first lane for final
     * horizontal add. */
    __m256i v_swap32a = _mm256_setr_epi32(0x0, 0x4, 0xF, 0xF,
                                          0xF, 0xF, 0xF, 0xF);
    v_delta = _mm256_permutexvar_epi32(v_swap32a, v_delta);

    v_delta = _mm256_hadd_epi32(v_delta, v_zeros);

    /* 32-bit word 0 now holds the full (up to 21-bit) sum.  Adding its two
     * 16-bit halves directly with _mm256_hadd_epi16() can overflow 16 bits
     * and silently drop the carry, producing an off-by-one checksum.  So
     * first split the halves into zero-padded 32-bit lanes (v_swap16a maps
     * bytes 0..3 of the lane to 32-bit words 0 and 1) and add them as 32-bit
     * values... */
    v_delta = _mm256_shuffle_epi8(v_delta, v_swap16a);
    v_delta = _mm256_hadd_epi32(v_delta, v_zeros);

    /* ...then fold the resulting (at most 17-bit) value once more.  This
     * final 16-bit addition can no longer overflow. */
    v_delta = _mm256_hadd_epi16(v_delta, v_zeros);

    /* Extract delta value. */
    return _mm256_extract_epi16(v_delta, 0);
}

/* Returns the checksum delta caused by rewriting only the IPv4 source and
 * destination addresses.  'old_header' and 'new_header' hold the first 20
 * bytes of the original and of the rewritten IPv4 header.  Every field other
 * than ip_src/ip_dst is taken from 'old_header', so changes in 'new_header'
 * to e.g. TTL, TOS or the header checksum do not contribute. */
static inline uint16_t ALWAYS_INLINE
avx512_ipv4_addr_csum_delta(__m256i old_header, __m256i new_header)
{
    /* 16-bit words 6..9 (header bytes 12..19) are ip_src and ip_dst. */
    const __mmask16 k_addr_words = 0x03C0;

    /* Old header with only the address words replaced by the new ones. */
    __m256i v_addr_only = _mm256_mask_blend_epi16(k_addr_words, old_header,
                                                  new_header);

    /* Bitwise NOT of the old header, computed as XOR with all ones. */
    __m256i v_old_inv = _mm256_xor_si256(old_header, _mm256_set1_epi32(-1));

    return avx512_get_delta(v_old_inv, v_addr_only);
}

/* Returns the checksum delta between 'old_header' and 'new_header', which
 * both hold the first 20 bytes of an IPv4 header with the rest of the
 * register zeroed.  The csum field of 'new_header' is expected to still hold
 * the old value (not yet updated or reset). */
static inline uint16_t ALWAYS_INLINE
avx512_ipv4_hdr_csum_delta(__m256i old_header, __m256i new_header)
{
    /* Ones' complement subtraction: add the bitwise NOT of the old words. */
    const __m256i v_all_ones = _mm256_set1_epi64x(-1);
    const __m256i v_old_inv = _mm256_xor_si256(old_header, v_all_ones);

    return avx512_get_delta(v_old_inv, new_header);
}

/* This function performs the same operation on each packet in the batch as
 * the scalar odp_set_ipv4() function. */
static void
action_avx512_ipv4_set_addrs(struct dp_packet_batch *batch,
                             const struct nlattr *a)
{
    const struct ovs_key_ipv4 *key, *mask;
    struct dp_packet *packet;
    a = nl_attr_get(a);
    key = nl_attr_get(a);
    mask = odp_get_key_mask(a, struct ovs_key_ipv4);

    /* Read the content of the key(src) and mask in the respective registers.
     * We only load the size of the actual structure, which is only 96-bits. */
    __m256i v_key = _mm256_maskz_loadu_epi32(0x7, (void *) key);
    __m256i v_mask = _mm256_maskz_loadu_epi32(0x7, (void *) mask);

    /* This two shuffle masks, v_shuf32, v_shuffle, are to shuffle key and
     * mask to match the ip_header structure layout. */
    static const uint8_t ip_shuffle_mask[32] = {
            0xFF, 0x05, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
            0x06, 0xFF, 0xFF, 0xFF, 0x00, 0x01, 0x02, 0x03,
            0x00, 0x01, 0x02, 0x03, 0xFF, 0xFF, 0xFF, 0xFF,
            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF};

    __m256i v_shuf32 = _mm256_setr_epi32(0x0, 0x2, 0xF, 0xF,
                                         0x1, 0xF, 0xF, 0xF);

    __m256i v_shuffle = _mm256_loadu_si256((void *) ip_shuffle_mask);

    /* Two shuffles are required for key and mask to match the layout of
     * the ip_header struct. The _shuffle_epi8 only works within 128-bit
     * lanes, so a permute is required to move src and dst into the correct
     * lanes. And then a shuffle is used to move the fields into the right
     * order. */
    __m256i v_key_shuf = _mm256_permutexvar_epi32(v_shuf32, v_key);
    v_key_shuf = _mm256_shuffle_epi8(v_key_shuf, v_shuffle);

    __m256i v_mask_shuf = _mm256_permutexvar_epi32(v_shuf32, v_mask);
    v_mask_shuf = _mm256_shuffle_epi8(v_mask_shuf, v_shuffle);

    DP_PACKET_BATCH_FOR_EACH (i, packet, batch) {
        struct ip_header *nh = dp_packet_l3(packet);
        ovs_be16 old_csum = ~nh->ip_csum;

        /* Load the 20 bytes of the IPv4 header. Without options, which is the
         * most common case it's 20 bytes, but can be up to 60 bytes. */
        __m256i v_packet = _mm256_maskz_loadu_epi32(0x1F, (void *) nh);

        /* AND the v_pkt_mask to the packet data (v_packet). */
        __m256i v_pkt_masked = _mm256_andnot_si256(v_mask_shuf, v_packet);

        /* OR the new addresses (v_key_shuf) with the masked packet addresses
         * (v_pkt_masked). */
        __m256i v_new_hdr = _mm256_or_si256(v_key_shuf, v_pkt_masked);

        /* Update the IP checksum based on updated IP values. */
        uint16_t delta = avx512_ipv4_hdr_csum_delta(v_packet, v_new_hdr);
        uint32_t new_csum = old_csum + delta;
        delta = csum_finish(new_csum);

        /* Insert new checksum. */
        v_new_hdr = _mm256_insert_epi16(v_new_hdr, delta, 5);

        /* If ip_src or ip_dst has been modified, L4 checksum needs to
         * be updated too. */
        if (mask->ipv4_src || mask->ipv4_dst) {

            uint16_t delta_checksum = avx512_ipv4_addr_csum_delta(v_packet,
                                                                  v_new_hdr);
            size_t l4_size = dp_packet_l4_size(packet);

            if (nh->ip_proto == IPPROTO_UDP && l4_size >= UDP_HEADER_LEN) {
                /* New UDP checksum. */
                struct udp_header *uh = dp_packet_l4(packet);
                if (uh->udp_csum) {
                    uint16_t old_udp_checksum = ~uh->udp_csum;
                    uint32_t udp_checksum = old_udp_checksum + delta_checksum;
                    udp_checksum = csum_finish(udp_checksum);

                    if (!udp_checksum) {
                        udp_checksum = htons(0xffff);
                    }
                    /* Insert new udp checksum. */
                    uh->udp_csum = udp_checksum;
                }
            } else if (nh->ip_proto == IPPROTO_TCP &&
                       l4_size >= TCP_HEADER_LEN) {
                /* New TCP checksum. */
                struct tcp_header *th = dp_packet_l4(packet);
                uint16_t old_tcp_checksum = ~th->tcp_csum;
                uint32_t tcp_checksum = old_tcp_checksum + delta_checksum;
                tcp_checksum = csum_finish(tcp_checksum);

                th->tcp_csum = tcp_checksum;
            }

            pkt_metadata_init_conn(&packet->md);
        }
        /* Write back the modified IPv4 addresses. */
        _mm256_mask_storeu_epi32((void *) nh, 0x1F, v_new_hdr);
    }
}

#if HAVE_AVX512VBMI
/* Returns the ones' complement sum of the IPv6 source and destination
 * addresses (header bytes 8..39) in 'ip6_header', folded to 16 bits and not
 * inverted. */
static inline uint16_t ALWAYS_INLINE
__attribute__((__target__("avx512vbmi")))
avx512_ipv6_sum_header(__m512i ip6_header)
{
    __m256i v_zeros = _mm256_setzero_si256();
    __m512i v_shuf_src_dst = _mm512_setr_epi64(0x01, 0x02, 0x03, 0x04,
                                               0xFF, 0xFF, 0xFF, 0xFF);

    /* Shuffle ip6 src and dst to beginning of register. */
    __m512i v_ip6_hdr_shuf = _mm512_permutexvar_epi64(v_shuf_src_dst,
                                                      ip6_header);

    /* Extract ip6 src and dst into smaller 256-bit wide register. */
    __m256i v_ip6_src_dst = _mm512_extracti64x4_epi64(v_ip6_hdr_shuf, 0);

    /* These two shuffle masks, v_swap16a and v_swap16b, are to shuffle the
     * src and dst fields and add padding after each 16-bit value for the
     * following carry over addition. */
    __m256i v_swap16a = _mm256_setr_epi16(0x0100, 0xFFFF, 0x0302, 0xFFFF,
                                          0x0504, 0xFFFF, 0x0706, 0xFFFF,
                                          0x0100, 0xFFFF, 0x0302, 0xFFFF,
                                          0x0504, 0xFFFF, 0x0706, 0xFFFF);
    __m256i v_swap16b = _mm256_setr_epi16(0x0908, 0xFFFF, 0x0B0A, 0xFFFF,
                                          0x0D0C, 0xFFFF, 0x0F0E, 0xFFFF,
                                          0x0908, 0xFFFF, 0x0B0A, 0xFFFF,
                                          0x0D0C, 0xFFFF, 0x0F0E, 0xFFFF);
    __m256i v_shuf_old1 = _mm256_shuffle_epi8(v_ip6_src_dst, v_swap16a);
    __m256i v_shuf_old2 = _mm256_shuffle_epi8(v_ip6_src_dst, v_swap16b);

    /* Add the two sets of zero-padded 16-bit words together. */
    __m256i v_delta = _mm256_add_epi32(v_shuf_old1, v_shuf_old2);

    /* Perform horizontal add to go from 8x32-bits to 2x32-bits. */
    v_delta = _mm256_hadd_epi32(v_delta, v_zeros);
    v_delta = _mm256_hadd_epi32(v_delta, v_zeros);

    /* Shuffle 32-bit value from 3rd lane into first lane for final
     * horizontal add. */
    __m256i v_swap32a = _mm256_setr_epi32(0x0, 0x4, 0xF, 0xF,
                                          0xF, 0xF, 0xF, 0xF);

    v_delta = _mm256_permutexvar_epi32(v_swap32a, v_delta);
    v_delta = _mm256_hadd_epi32(v_delta, v_zeros);

    /* 32-bit word 0 now holds the full sum of 16 words (up to 0xFFFF0).
     * Folding its halves with a single 16-bit add could overflow and drop
     * the carry, so first add the zero-padded halves as 32-bit values (the
     * lower lane of v_swap16a maps bytes 0..3 to 32-bit words 0 and 1)... */
    v_delta = _mm256_shuffle_epi8(v_delta, v_swap16a);
    v_delta = _mm256_hadd_epi32(v_delta, v_zeros);

    /* ...then fold the (at most 17-bit) result once more; this 16-bit add
     * cannot overflow. */
    v_delta = _mm256_hadd_epi16(v_delta, v_zeros);

    /* Extract delta value. */
    return _mm256_extract_epi16(v_delta, 0);
}

/* Returns the ones' complement difference sum(new) - sum(old), where each
 * sum covers the IPv6 source and destination addresses of the respective
 * header, folded to 16 bits and not inverted.  The result can be added to an
 * inverted L4 checksum. */
static inline uint16_t ALWAYS_INLINE
__attribute__((__target__("avx512vbmi")))
avx512_ipv6_addr_csum_delta(__m512i old_header, __m512i new_header)
{
    /* Ones' complement subtraction: add the complement of the old sum. */
    uint16_t neg_old_sum = ~avx512_ipv6_sum_header(old_header);
    uint32_t acc = neg_old_sum;

    acc += avx512_ipv6_sum_header(new_header);

    /* csum_finish() folds and inverts; undo the inversion. */
    return ~csum_finish(acc);
}

/* This function performs the same operation on each packet in the batch as
 * the scalar odp_set_ipv6() function.  It applies the masked IPv6 key carried
 * in the OVS_ACTION_ATTR_SET_MASKED attribute 'a' to the 40-byte IPv6 header
 * of each packet.  If the mask covers the source or destination address, it
 * also incrementally updates the UDP, TCP or ICMPv6 checksum.
 *
 * ip_shuffle_mask only places the key's flow label, next header, hop limit
 * and addresses.  Actions that also rewrite the traffic class are therefore
 * handed to the scalar implementation (see below). */
static void
__attribute__((__target__("avx512vbmi")))
action_avx512_set_ipv6(struct dp_packet_batch *batch, const struct nlattr *a)
{
    const struct nlattr *set_masked_attr = a;
    const struct ovs_key_ipv6 *key, *mask;
    struct dp_packet *packet;

    a = nl_attr_get(a);
    key = nl_attr_get(a);
    mask = odp_get_key_mask(a, struct ovs_key_ipv6);

    /* The traffic class occupies bits 20..27 of the big-endian ip6_flow
     * word, straddling the nibble boundary between header bytes 0 and 1.
     * The byte permute below cannot express that, and key->ipv6_tclass (key
     * byte 0x25) is not referenced by ip_shuffle_mask.  Rather than silently
     * ignoring a requested traffic class rewrite, run the whole action
     * through the scalar implementation. */
    if (mask->ipv6_tclass) {
        odp_execute_scalar_action(batch, set_masked_attr);
        return;
    }

    /* Read the content of the key and mask in the respective registers. We
     * only load the size of the actual structure, which is only 40 bytes. */
    __m512i v_key = _mm512_maskz_loadu_epi64(0x1F, (void *) key);
    __m512i v_mask = _mm512_maskz_loadu_epi64(0x1F, (void *) mask);

    /* This permute mask reorders key and mask into the ip6_hdr layout.
     * Output byte j takes input byte ip_shuffle_mask[j]: bytes 0..3 get the
     * flow label word (key 0x20..0x23), byte 6 the next header (0x24), byte 7
     * the hop limit (0x26) and bytes 8..39 the src and dst addresses.  Index
     * 0xFF selects a byte beyond the 40 loaded bytes, which is zero. */
    static const uint8_t ip_shuffle_mask[64] = {
        0x20, 0x21, 0x22, 0x23, 0xFF, 0xFF, 0x24, 0x26,
        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
        0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
        0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
        0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF
    };

    __m512i v_shuffle = _mm512_loadu_si512((void *) ip_shuffle_mask);

    /* This shuffle is required for key and mask to match the layout of the
     * ip6_hdr struct. */
    __m512i v_key_shuf = _mm512_permutexvar_epi8(v_shuffle, v_key);
    __m512i v_mask_shuf = _mm512_permutexvar_epi8(v_shuffle, v_mask);

    /* Set the v_zero register to all zero's. */
    const __m128i v_zeros = _mm_setzero_si128();

    /* Set the v_all_ones register to all one's. */
    const __m128i v_all_ones = _mm_cmpeq_epi16(v_zeros, v_zeros);

    /* Load ip6 src and dst masks respectively into 128-bit wide registers. */
    __m128i v_src = _mm_loadu_si128((void *) &mask->ipv6_src);
    __m128i v_dst = _mm_loadu_si128((void *) &mask->ipv6_dst);

    /* Perform a bitwise OR between src and dst registers. */
    __m128i v_or = _mm_or_si128(v_src, v_dst);

    /* True if any bit of the src or dst address mask is set, i.e. if an
     * address may be rewritten. */
    bool do_checksum = !_mm_test_all_zeros(v_or, v_all_ones);

    DP_PACKET_BATCH_FOR_EACH (i, packet, batch) {
        struct ovs_16aligned_ip6_hdr *nh = dp_packet_l3(packet);

        /* Load the 40 bytes of the IPv6 header. */
        __m512i v_packet = _mm512_maskz_loadu_epi64(0x1F, (void *) nh);

        /* Clear the masked bits of the packet data (v_packet & ~mask). */
        __m512i v_pkt_masked = _mm512_andnot_si512(v_mask_shuf, v_packet);

        /* OR the new values (v_key_shuf) into the masked packet header
         * (v_pkt_masked). */
        __m512i v_new_hdr = _mm512_or_si512(v_key_shuf, v_pkt_masked);

        /* If ip6_src or ip6_dst has been modified, L4 checksum needs to be
         * updated. */
        uint8_t proto = 0;
        bool rh_present;
        bool do_csum = do_checksum;

        rh_present = packet_rh_present(packet, &proto, &do_csum);

        if (do_csum) {
            size_t l4_size = dp_packet_l4_size(packet);
            __m512i v_new_hdr_for_cksum = v_new_hdr;
            uint16_t delta_checksum;

            /* If a routing header is present, the destination address must
             * not contribute to the checksum update, so take 64-bit words 3
             * and 4 (ip6_dst) from the original header. */
            if (rh_present) {
                v_new_hdr_for_cksum = _mm512_mask_blend_epi64(0x18, v_new_hdr,
                                                              v_packet);
            }

            delta_checksum = avx512_ipv6_addr_csum_delta(v_packet,
                                                         v_new_hdr_for_cksum);

            if (proto == IPPROTO_UDP && l4_size >= UDP_HEADER_LEN) {
                struct udp_header *uh = dp_packet_l4(packet);

                if (uh->udp_csum) {
                    uint16_t old_udp_checksum = ~uh->udp_csum;
                    uint32_t udp_checksum = old_udp_checksum + delta_checksum;

                    udp_checksum = csum_finish(udp_checksum);

                    if (!udp_checksum) {
                        udp_checksum = htons(0xffff);
                    }

                    uh->udp_csum = udp_checksum;
                }
            } else if (proto == IPPROTO_TCP && l4_size >= TCP_HEADER_LEN) {
                struct tcp_header *th = dp_packet_l4(packet);
                uint16_t old_tcp_checksum = ~th->tcp_csum;
                uint32_t tcp_checksum = old_tcp_checksum + delta_checksum;

                tcp_checksum = csum_finish(tcp_checksum);
                th->tcp_csum = tcp_checksum;
            } else if (proto == IPPROTO_ICMPV6 &&
                       l4_size >= sizeof(struct icmp6_header)) {
                struct icmp6_header *icmp = dp_packet_l4(packet);
                uint16_t old_icmp6_checksum = ~icmp->icmp6_cksum;
                uint32_t icmp6_checksum = old_icmp6_checksum + delta_checksum;

                icmp6_checksum = csum_finish(icmp6_checksum);
                icmp->icmp6_cksum = icmp6_checksum;
            }

            pkt_metadata_init_conn(&packet->md);
        }
        /* Write back the modified IPv6 header. */
        _mm512_mask_storeu_epi64((void *) nh, 0x1F, v_new_hdr);
    }
}
#endif /* HAVE_AVX512VBMI */

static void
action_avx512_set_masked(struct dp_packet_batch *batch, const struct nlattr *a)
{
    const struct nlattr *mask = nl_attr_get(a);
    enum ovs_key_attr attr_type = nl_attr_type(mask);

    if (attr_type <= OVS_KEY_ATTR_MAX && impl_set_masked_funcs[attr_type]) {
        impl_set_masked_funcs[attr_type](batch, a);
    } else {
        odp_execute_scalar_action(batch, a);
    }
}

int
action_avx512_init(struct odp_execute_action_impl *self OVS_UNUSED)
{
    if (!action_avx512_isa_probe()) {
        return -ENOTSUP;
    }

    /* Set function pointers for actions that can be applied directly, these
     * are identified by OVS_ACTION_ATTR_*. */
    self->funcs[OVS_ACTION_ATTR_POP_VLAN] = action_avx512_pop_vlan;
    self->funcs[OVS_ACTION_ATTR_PUSH_VLAN] = action_avx512_push_vlan;
    self->funcs[OVS_ACTION_ATTR_SET_MASKED] = action_avx512_set_masked;

    /* Set function pointers for the individual operations supported by the
     * SET_MASKED action. */
    impl_set_masked_funcs[OVS_KEY_ATTR_ETHERNET] = action_avx512_eth_set_addrs;
    impl_set_masked_funcs[OVS_KEY_ATTR_IPV4] = action_avx512_ipv4_set_addrs;

#if HAVE_AVX512VBMI
    if (action_avx512vbmi_isa_probe()) {
        impl_set_masked_funcs[OVS_KEY_ATTR_IPV6] = action_avx512_set_ipv6;
    }
#endif

    return 0;
}

#endif /* Sparse */

#else /* __x86_64__ */

#include <config.h>
#include <errno.h>
#include "odp-execute-private.h"
/* Builds for non-x86_64 targets (e.g. 32-bit builds) still need this
 * function, because it is required to be called.  This stub always reports
 * the AVX512 implementation as unavailable. */

int
action_avx512_init(struct odp_execute_action_impl *self OVS_UNUSED)
{
    return -ENOTSUP;
}

#endif