| # Copyright (c) 2022 Intel and/or its affiliates. |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at: |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| from ProtocolHeaderAttribute import * |
| from ProtocolHeaderField import * |
| from InputFormat import * |
| import ExpressionConverter |
| import copy |
| |
| |
class ProtocolHeader:
    """One protocol header instance built from a parsed header description.

    The constructor wraps every field and attribute of ``node`` in a
    ``ProtocolHeaderField`` / ``ProtocolHeaderAttribute`` object.  Callers
    then set values and masks by name, call ``Adjust()`` to recompute sizes
    and auto-increase fields, and finally call ``Resolve()`` to serialize the
    header into ``self.Buffer`` (values) and ``self.Mask`` (masks), both lists
    of byte-sized integers in most-significant-byte-first order.

    Field sizes are handled as bit counts throughout this class (they are
    shifted right by 3 whenever a byte count is needed).
    """

    def __init__(self, node):
        """Create the per-field and per-attribute state for ``node``.

        Args:
            node: header description exposing ``Name``, ``fields`` and
                ``attributes``.  Each field must expose ``Size``,
                ``DefaultValue`` and ``Name``; each attribute likewise.
        """
        self.fields = []
        self.attributes = []
        self.fieldDict = {}
        self.attributeDict = {}
        self.Buffer = []
        self.Mask = []

        self.node = node
        for field in self.node.fields:
            phf = ProtocolHeaderField(field.Size, field.DefaultValue, None, field)
            self.fields.append(phf)
            # Fields named "reserved" stay in the ordered list (they still
            # occupy bits) but cannot be addressed by name.
            if field.Name != "reserved":
                self.fieldDict[field.Name] = phf

        for attr in self.node.attributes:
            pha = ProtocolHeaderAttribute(attr.Size, attr.DefaultValue, attr)
            self.attributes.append(pha)
            self.attributeDict[attr.Name] = pha

    def Name(self):
        """Return the name of the underlying node."""
        return self.node.Name

    def Fields(self):
        """Return the ordered list of field wrappers."""
        return self.fields

    def Attributes(self):
        """Return the ordered list of attribute wrappers."""
        return self.attributes

    def setField(self, name, expression, auto):
        """Update the value of the field ``name``.

        Args:
            name: field name; "reserved" and unknown names are rejected.
            expression: new value expression, forwarded to the field's
                ``UpdateValue``.
            auto: flag forwarded unchanged to ``UpdateValue``.

        Returns:
            True when the field exists and ``UpdateValue`` reported success
            (the field's ``UpdateSize`` is then called), False otherwise.
        """
        if name == "reserved" or name not in self.fieldDict:
            return False

        field = self.fieldDict[name]
        if not field.UpdateValue(expression, auto):
            return False

        field.UpdateSize()
        return True

    def SetField(self, name, expression):
        """Set a field value with the ``auto`` flag cleared."""
        return self.setField(name, expression, False)

    def SetFieldAuto(self, name, expression):
        """Set a field value with the ``auto`` flag set."""
        return self.setField(name, expression, True)

    def SetAttribute(self, name, expression):
        """Update attribute ``name``; return False when it does not exist."""
        if name not in self.attributeDict:
            return False

        return self.attributeDict[name].UpdateValue(expression)

    def SetMask(self, name, expression):
        """Update the mask of field ``name``; return False when unknown."""
        if name not in self.fieldDict:
            return False

        return self.fieldDict[name].UpdateMask(expression)

    def resolveOptional(self, condition):
        """Evaluate an optional-field condition against current field values.

        Grammar handled here: alternatives separated by ``|``, each made of
        terms separated by ``&``, each term being ``key=value`` or
        ``key!=value``.

        Args:
            condition: condition string, or None for "always present".

        Returns:
            True when ``condition`` is None or holds.  A term without a
            comparison operator, or one naming an unknown field, is False.
        """
        if condition is None:
            return True

        alternatives = condition.split("|")
        if len(alternatives) > 1:
            return any(self.resolveOptional(term) for term in alternatives)

        conjuncts = condition.split("&")
        if len(conjuncts) > 1:
            return all(self.resolveOptional(term) for term in conjuncts)

        # Single comparison.  "!=" must be tested first because it also
        # contains "=".
        if "!=" in condition:
            key, _, value = condition.partition("!=")
            negate = True
        elif "=" in condition:
            key, _, value = condition.partition("=")
            negate = False
        else:
            return False

        key = key.strip()
        value = value.strip()

        if key not in self.fieldDict:
            return False

        equal = ExpressionConverter.Equal(self.fieldDict[key].Value, value)
        return not equal if negate else equal

    def resolveSize(self, exp):
        """Resolve a variable-size expression to a number.

        ``exp`` is either a bare name or ``name << shift``.  The name is
        looked up among the fields first and the attributes second; its
        current value is converted with ``ExpressionConverter.ToNum`` and
        shifted left by ``shift``.

        Returns:
            The shifted number, or 0 when the name is unknown or its value
            converts to nothing / zero.
        """
        shift = 0
        key = exp

        if "<<" in exp:
            name, _, amount = exp.partition("<<")
            key = name.strip()
            shift = int(amount.strip())

        # Fields take precedence over attributes of the same name.
        if key in self.fieldDict:
            source = self.fieldDict[key]
        elif key in self.attributeDict:
            source = self.attributeDict[key]
        else:
            return 0

        _, number = ExpressionConverter.ToNum(source.Value)
        return number << shift if number else 0

    def Adjust(self):
        """Recompute field sizes and bump auto-increase fields.

        After ``resolveAllSize()``, every field flagged ``IsAutoIncrease`` is
        increased by the byte size of every field flagged
        ``IsIncreaseLength`` whose optional condition currently holds.
        """
        self.resolveAllSize()

        autoIncreases = [phf for phf in self.fields if phf.Field.IsAutoIncrease]
        increaseHeaders = [
            phf
            for phf in self.fields
            if phf.Field.IsIncreaseLength
            and self.resolveOptional(phf.Field.Optional)
        ]

        for target in autoIncreases:
            for contributor in increaseHeaders:
                target.UpdateValue(
                    ExpressionConverter.IncreaseValue(
                        target.Value, contributor.Size >> 3
                    ),
                    True,
                )

    def resolveAllSize(self):
        """Assign the effective size to every field.

        A field whose optional condition is false gets size 0; otherwise the
        size comes from its ``VariableSize`` expression when present, else
        from its static ``Size``.
        """
        for phf in self.fields:
            description = phf.Field
            if description.Optional is not None and not self.resolveOptional(
                description.Optional
            ):
                phf.Size = 0
            elif description.VariableSize is not None:
                phf.Size = self.resolveSize(description.VariableSize)
            else:
                phf.Size = description.Size

    def GetSize(self):
        """Return the summed size of all fields, converted from bits to bytes."""
        return sum(field.Size for field in self.fields) >> 3

    def AppendAuto(self, size):
        """Increase every auto-increase field by ``size``."""
        for phf in self.fields:
            if not phf.Field.IsAutoIncrease:
                continue

            phf.UpdateValue(ExpressionConverter.IncreaseValue(phf.Value, size), True)

    def getField(self, name):
        """Return the value of field ``name``, or None when it is unknown."""
        if name not in self.fieldDict:
            return None

        return self.fieldDict[name].Value

    def getAttribute(self, name):
        """Return the value of attribute ``name``, or None when unknown."""
        if name not in self.attributeDict:
            return None

        return self.attributeDict[name].Value

    def GetValue(self, name):
        """Return the field value for ``name``, falling back to the attribute.

        The attribute is consulted whenever the field lookup yields None,
        which also covers an existing field whose value is None.
        """
        result = self.getField(name)
        if result is None:
            return self.getAttribute(name)

        return result

    def appendNum(self, big, exp, size):
        """Append a ``size``-bit number to the accumulator ``big``.

        Args:
            big: integer accumulator holding the bits emitted so far.
            exp: numeric expression, or None for zero.
            size: field width in bits; higher bits of the number are dropped.

        Returns:
            The accumulator shifted left by ``size`` with the number in the
            low bits.  An expression that does not convert is reported and
            emitted as zero so that the accumulator stays a valid integer.
        """
        num = 0
        if exp is not None:
            _, num = ExpressionConverter.ToNum(exp)
            if num is None:
                print("Invalid byte expression")
                num = 0

        # cut msb
        num &= (1 << size) - 1
        return (big << size) | num

    def appendUInt64(self, big, exp, size):
        """Append a number of up to 64 bits to the accumulator ``big``.

        Same contract as ``appendNum`` except that the value is only
        truncated when ``size`` is below 64.  A legitimate value of zero is
        accepted; only a failed conversion is reported (and emitted as zero).
        """
        u64 = 0
        if exp is not None:
            _, u64 = ExpressionConverter.ToNum(exp)
            if u64 is None:
                print("Invalid UInt64 expression")
                u64 = 0

        # cut msb
        if size < 64:
            u64 &= (1 << size) - 1
        return (big << size) | u64

    def appendIPv4(self, big, exp):
        """Append an IPv4 address to ``big``, one byte at a time.

        ``exp`` None means all zeroes.  An address that fails to convert is
        reported and emitted as four zero bytes.
        """
        ipv4 = bytes(4)
        if exp is not None:
            _, ipv4 = ExpressionConverter.ToIPv4Address(exp)
            if not ipv4:
                print("Invalid IPv4 Address")
                ipv4 = bytes(4)

        for octet in ipv4:
            big = (big << 8) | octet

        return big

    def appendIPv6(self, big, exp):
        """Append the 16 bytes of an IPv6 address to ``big``.

        ``exp`` None means all zeroes.  An address that fails to convert is
        reported and emitted as sixteen zero bytes.
        """
        ipv6 = bytes(16)
        if exp is not None:
            _, ipv6 = ExpressionConverter.ToIPv6Address(exp)
            if not ipv6:
                print("Invalid IPv6 Address")
                ipv6 = bytes(16)

        for i in range(16):
            big = (big << 8) | ipv6[i]

        return big

    def appendMAC(self, big, exp):
        """Append the 6 bytes of a MAC address to ``big``.

        ``exp`` None means all zeroes.  An address that fails to convert is
        reported and emitted as six zero bytes.
        """
        mac = bytes(6)
        if exp is not None:
            _, mac = ExpressionConverter.ToMacAddress(exp)
            if not mac:
                print("Invalid MAC Address")
                mac = bytes(6)

        for i in range(6):
            big = (big << 8) | mac[i]

        return big

    def appendByteArray(self, big, exp, size):
        """Append ``size >> 3`` bytes taken from a byte-array expression.

        Bytes missing from a short array are emitted as zero; surplus bytes
        are ignored.  ``exp`` None means all zeroes.  An array that fails to
        convert (or is empty) is reported and emitted as zeroes.
        """
        byteCount = size >> 3
        array = bytes(byteCount)
        if exp is not None:
            _, array = ExpressionConverter.ToByteArray(exp)
            if not array:
                print("Invalid byte array")
                array = bytes(byteCount)

        for i in range(byteCount):
            big <<= 8
            if i < len(array):
                big |= array[i]

        return big

    def append(self, big, phf):
        """Append one field's value and mask to the accumulators in ``big``.

        Args:
            big: dict with integer entries "bigVal" and "bigMsk"; it is
                updated in place.
            phf: field wrapper to serialize.

        Returns:
            A tuple ``(big, bits)`` where ``bits`` is ``phf.Size``.
        """
        bigVal = big["bigVal"]
        bigMsk = big["bigMsk"]
        size = phf.Size

        if phf.Field.IsReserved:
            # Reserved fields contribute zero bits to both value and mask.
            bigVal <<= size
            bigMsk <<= size
            big.update(bigVal=bigVal, bigMsk=bigMsk)
            return big, size

        fmt = phf.Field.Format

        if fmt in (InputFormat.u8, InputFormat.u16, InputFormat.u32):
            bigVal = self.appendNum(bigVal, phf.Value, size)
            bigMsk = self.appendNum(bigMsk, phf.Mask, size)

        elif fmt == InputFormat.u64:
            bigVal = self.appendUInt64(bigVal, phf.Value, size)
            bigMsk = self.appendUInt64(bigMsk, phf.Mask, size)

        elif fmt == InputFormat.ipv4:
            bigVal = self.appendIPv4(bigVal, phf.Value)
            bigMsk = self.appendIPv4(bigMsk, phf.Mask)

        elif fmt == InputFormat.ipv6:
            bigVal = self.appendIPv6(bigVal, phf.Value)
            bigMsk = self.appendIPv6(bigMsk, phf.Mask)

        elif fmt == InputFormat.mac:
            bigVal = self.appendMAC(bigVal, phf.Value)
            bigMsk = self.appendMAC(bigMsk, phf.Mask)

        elif fmt == InputFormat.bytearray:
            bigVal = self.appendByteArray(bigVal, phf.Value, size)
            bigMsk = self.appendByteArray(bigMsk, phf.Mask, size)

        else:
            print("Invalid input format")
            # Still reserve the field's bits so that the fields that follow
            # land at the offset the caller accounts for.
            bigVal <<= size
            bigMsk <<= size

        big.update(bigVal=bigVal, bigMsk=bigMsk)
        return big, size

    def Resolve(self):
        """Serialize all non-empty fields into ``self.Buffer`` / ``self.Mask``.

        Fields are appended in declaration order into two big integers
        (value and mask), which are then split into bytes with the most
        significant byte first.  One byte is produced per started group of
        8 bits of the summed field sizes.
        """
        big = {"bigVal": 0, "bigMsk": 0}
        totalBits = 0

        for phf in self.fields:
            if phf.Size == 0:
                continue

            big, bits = self.append(big, phf)
            totalBits += bits

        bigVal = big["bigVal"]
        bigMsk = big["bigMsk"]
        valueBytes = []
        maskBytes = []

        # Peel bytes off the low end, then flip to get MSB-first order.
        while totalBits > 0:
            valueBytes.append(bigVal & 0xFF)
            maskBytes.append(bigMsk & 0xFF)
            bigVal >>= 8
            bigMsk >>= 8
            totalBits -= 8

        valueBytes.reverse()
        maskBytes.reverse()

        self.Buffer = valueBytes
        self.Mask = maskBytes