blob: 398a52d34551cf2db25b378c360d5fa03fe20995 [file] [log] [blame]
# Copyright (c) 2022 Intel and/or its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ProtocolHeaderAttribute import *
from ProtocolHeaderField import *
from InputFormat import *
import ExpressionConverter
import copy
class ProtocolHeader:
    """A concrete instance of a protocol header definition.

    Wraps a header definition ``node`` (which supplies ``Name``, ``fields`` and
    ``attributes``) and tracks a value/mask per field plus a value per
    attribute.  ``Resolve()`` packs the fields, in definition order, into the
    byte lists ``Buffer`` and ``Mask``.  Field sizes are tracked in bits.
    """

    def __init__(self, node):
        self.fields = []  # one ProtocolHeaderField per node field, in order
        self.attributes = []  # one ProtocolHeaderAttribute per node attribute
        self.fieldDict = {}  # name -> field; "reserved" fields are excluded
        self.attributeDict = {}  # name -> attribute
        self.Buffer = []  # packed field values, one int per byte (set by Resolve)
        self.Mask = []  # packed field masks, one int per byte (set by Resolve)
        self.node = node
        for field in self.node.fields:
            phf = ProtocolHeaderField(field.Size, field.DefaultValue, None, field)
            self.fields.append(phf)
            # Reserved fields still occupy space in the packed header but
            # cannot be addressed (and therefore not set) by name.
            if field.Name != "reserved":
                self.fieldDict[field.Name] = phf
        for attr in self.node.attributes:
            pha = ProtocolHeaderAttribute(attr.Size, attr.DefaultValue, attr)
            self.attributes.append(pha)
            self.attributeDict[attr.Name] = pha

    def Name(self):
        """Return the name of the underlying header definition."""
        return self.node.Name

    def Fields(self):
        """Return all fields (including reserved ones) in definition order."""
        return self.fields

    def Attributes(self):
        """Return all attributes in definition order."""
        return self.attributes

    def setField(self, name, expression, auto):
        """Update field ``name`` from ``expression`` and refresh its size.

        ``auto`` is forwarded unchanged to ``UpdateValue``.  Returns True on
        success; False for reserved/unknown names or a rejected update.
        """
        if name == "reserved" or name not in self.fieldDict:
            return False
        field = self.fieldDict[name]
        if not field.UpdateValue(expression, auto):
            return False
        field.UpdateSize()
        return True

    def SetField(self, name, expression):
        """Set field ``name`` with ``auto=False``; see ``setField``."""
        return self.setField(name, expression, False)

    def SetFieldAuto(self, name, expression):
        """Set field ``name`` with ``auto=True``; see ``setField``."""
        return self.setField(name, expression, True)

    def SetAttribute(self, name, expression):
        """Update attribute ``name``; returns False if it does not exist."""
        if name not in self.attributeDict:
            return False
        return self.attributeDict[name].UpdateValue(expression)

    def SetMask(self, name, expression):
        """Update the mask of field ``name``; returns False if it does not exist."""
        if name not in self.fieldDict:
            return False
        return self.fieldDict[name].UpdateMask(expression)

    def resolveOptional(self, condition):
        """Evaluate a field ``Optional`` condition against current field values.

        As parsed here, ``condition`` is a ``|``-separated list of alternatives,
        each a ``&``-separated list of ``key=value`` / ``key!=value`` terms
        (so ``&`` binds tighter than ``|``).

        Returns True when ``condition`` is None.  Terms naming an unknown field,
        or lacking an ``=``/``!=`` operator, evaluate to False.
        """
        if condition is None:
            return True
        alternatives = condition.split("|")
        if len(alternatives) > 1:
            return any(self.resolveOptional(alt) for alt in alternatives)
        conjuncts = condition.split("&")
        if len(conjuncts) > 1:
            return all(self.resolveOptional(term) for term in conjuncts)

        # "!=" must be tested first because it also contains "=".
        if "!=" in condition:
            key, _, value = condition.partition("!=")
            negate = True
        elif "=" in condition:
            key, _, value = condition.partition("=")
            negate = False
        else:
            return False
        key = key.strip()
        value = value.strip()
        if key not in self.fieldDict:
            return False
        equal = ExpressionConverter.Equal(self.fieldDict[key].Value, value)
        return not equal if negate else equal

    def resolveSize(self, exp):
        """Resolve a variable-size expression ``"<name>"`` or ``"<name> << N"``.

        ``name`` is looked up among fields first, then attributes; its value is
        converted with ``ExpressionConverter.ToNum`` and shifted left by ``N``.
        Returns 0 for unknown names or a falsy numeric value.
        """
        shift = 0
        key = exp
        if "<<" in exp:
            key, _, amount = exp.partition("<<")
            key = key.strip()
            shift = int(amount.strip())
        if key in self.fieldDict:
            value = self.fieldDict[key].Value
        elif key in self.attributeDict:
            value = self.attributeDict[key].Value
        else:
            return 0
        _, num = ExpressionConverter.ToNum(value)
        return num << shift if num else 0

    def Adjust(self):
        """Recompute field sizes, then update auto-increase fields.

        Every field flagged ``IsAutoIncrease`` is increased by the byte size
        (``Size >> 3``) of each ``IsIncreaseLength`` field whose ``Optional``
        condition currently holds.
        """
        self.resolveAllSize()
        autoIncreases = []
        increaseHeaders = []
        for phf in self.fields:
            if phf.Field.IsAutoIncrease:
                autoIncreases.append(phf)
            if phf.Field.IsIncreaseLength and self.resolveOptional(phf.Field.Optional):
                increaseHeaders.append(phf)
        for f1 in autoIncreases:
            for f2 in increaseHeaders:
                f1.UpdateValue(
                    ExpressionConverter.IncreaseValue(f1.Value, f2.Size >> 3), True
                )

    def resolveAllSize(self):
        """Set each field's bit ``Size``.

        The size is 0 when the field's ``Optional`` condition fails, the
        resolved ``VariableSize`` when present, and the static size otherwise.
        """
        for phf in self.fields:
            optional = phf.Field.Optional
            if optional is not None and not self.resolveOptional(optional):
                phf.Size = 0
            elif phf.Field.VariableSize is not None:
                phf.Size = self.resolveSize(phf.Field.VariableSize)
            else:
                phf.Size = phf.Field.Size

    def GetSize(self):
        """Return the total header size in bytes (sum of field bit sizes >> 3)."""
        return sum(field.Size for field in self.fields) >> 3

    def AppendAuto(self, size):
        """Increase every auto-increase field's value by ``size``."""
        for phf in self.fields:
            if phf.Field.IsAutoIncrease:
                phf.UpdateValue(ExpressionConverter.IncreaseValue(phf.Value, size), True)

    def getField(self, name):
        """Return the value of field ``name``, or None if it does not exist."""
        if name not in self.fieldDict:
            return None
        return self.fieldDict[name].Value

    def getAttribute(self, name):
        """Return the value of attribute ``name``, or None if it does not exist."""
        if name not in self.attributeDict:
            return None
        return self.attributeDict[name].Value

    def GetValue(self, name):
        """Return field ``name``'s value, falling back to the attribute of that name."""
        result = self.getField(name)
        if result is None:
            return self.getAttribute(name)
        return result

    def _appendInt(self, big, exp, size, errorMessage):
        """Shift ``big`` left by ``size`` bits and OR in ``exp`` as an integer.

        The value is truncated to its ``size`` least-significant bits so an
        oversized value cannot spill into previously packed fields.  A None
        ``exp`` contributes zeros; an unparsable one prints ``errorMessage`` and
        also contributes zeros, keeping later fields aligned.
        """
        num = 0
        if exp is not None:
            _, parsed = ExpressionConverter.ToNum(exp)
            if parsed is None:
                print(errorMessage)
            else:
                num = parsed
        num &= (1 << size) - 1  # cut msb
        return (big << size) | num

    @staticmethod
    def _appendBytes(big, data, count):
        """Shift exactly ``count`` bytes of ``data`` into ``big``, first byte most
        significant; bytes beyond ``len(data)`` are zero-filled."""
        for i in range(count):
            big <<= 8
            if i < len(data):
                big |= data[i]
        return big

    def appendNum(self, big, exp, size):
        """Append a ``size``-bit integer field (u8/u16/u32) to ``big``."""
        return self._appendInt(big, exp, size, "Invalid byte expression")

    def appendUInt64(self, big, exp, size):
        """Append a ``size``-bit integer field (u64) to ``big``.

        A value of 0 is valid; only an unparsable expression is reported.
        """
        return self._appendInt(big, exp, size, "Invalid UInt64 expression")

    def appendIPv4(self, big, exp):
        """Append a 4-byte IPv4 address (zeros when ``exp`` is None or invalid)."""
        data = b""
        if exp is not None:
            _, data = ExpressionConverter.ToIPv4Address(exp)
            if not data:
                print("Inavalid IPv4 Address")
                data = b""
        return self._appendBytes(big, data, 4)

    def appendIPv6(self, big, exp):
        """Append a 16-byte IPv6 address (zeros when ``exp`` is None or invalid)."""
        data = b""
        if exp is not None:
            _, data = ExpressionConverter.ToIPv6Address(exp)
            if not data:
                print("Inavalid IPv6 Address")
                data = b""
        return self._appendBytes(big, data, 16)

    def appendMAC(self, big, exp):
        """Append a 6-byte MAC address (zeros when ``exp`` is None or invalid)."""
        data = b""
        if exp is not None:
            _, data = ExpressionConverter.ToMacAddress(exp)
            if not data:
                print("Inavalid MAC Address")
                data = b""
        return self._appendBytes(big, data, 6)

    def appendByteArray(self, big, exp, size):
        """Append ``size >> 3`` bytes from a byte-array expression.

        Short input is zero-padded on the right; a None or invalid ``exp``
        yields all zeros.
        """
        data = b""
        if exp is not None:
            _, data = ExpressionConverter.ToByteArray(exp)
            if not data:
                print("Invalid byte array")
                data = b""
        return self._appendBytes(big, data, size >> 3)

    def append(self, big, phf):
        """Pack field ``phf``'s value and mask into the accumulator ``big``.

        ``big`` is a dict with integer entries ``"bigVal"`` and ``"bigMsk"``,
        updated in place.  Returns ``(big, phf.Size)``.  Reserved fields, and
        fields with an unrecognised format, are packed as zeros so the
        following fields stay aligned.
        """
        size = phf.Size
        bigVal = big["bigVal"]
        bigMsk = big["bigMsk"]
        if phf.Field.IsReserved:
            bigVal <<= size
            bigMsk <<= size
        else:
            fmt = phf.Field.Format
            if fmt in (InputFormat.u8, InputFormat.u16, InputFormat.u32):
                bigVal = self.appendNum(bigVal, phf.Value, size)
                bigMsk = self.appendNum(bigMsk, phf.Mask, size)
            elif fmt == InputFormat.u64:
                bigVal = self.appendUInt64(bigVal, phf.Value, size)
                bigMsk = self.appendUInt64(bigMsk, phf.Mask, size)
            elif fmt == InputFormat.ipv4:
                bigVal = self.appendIPv4(bigVal, phf.Value)
                bigMsk = self.appendIPv4(bigMsk, phf.Mask)
            elif fmt == InputFormat.ipv6:
                bigVal = self.appendIPv6(bigVal, phf.Value)
                bigMsk = self.appendIPv6(bigMsk, phf.Mask)
            elif fmt == InputFormat.mac:
                bigVal = self.appendMAC(bigVal, phf.Value)
                bigMsk = self.appendMAC(bigMsk, phf.Mask)
            elif fmt == InputFormat.bytearray:
                bigVal = self.appendByteArray(bigVal, phf.Value, size)
                bigMsk = self.appendByteArray(bigMsk, phf.Mask, size)
            else:
                print("Invalid input format")
                # Still reserve the field's bits: the caller counts `size`
                # toward the total, so skipping the shift would misalign
                # every following field.
                bigVal <<= size
                bigMsk <<= size
        big.update(bigVal=bigVal, bigMsk=bigMsk)
        return big, size

    def Resolve(self):
        """Pack all non-empty fields into ``self.Buffer`` / ``self.Mask``.

        Both are lists of byte values (0-255), most significant byte first,
        with ``ceil(total_bits / 8)`` entries.
        """
        big = {"bigVal": 0, "bigMsk": 0}
        offset = 0
        for phf in self.fields:
            if phf.Size == 0:
                continue
            big, bits = self.append(big, phf)
            offset += bits
        bigVal = big["bigVal"]
        bigMsk = big["bigMsk"]
        buffer = []
        mask = []
        # Peel bytes off the least significant end, then reverse to MSB-first.
        while offset > 0:
            buffer.append(bigVal & 0xFF)
            mask.append(bigMsk & 0xFF)
            bigVal >>= 8
            bigMsk >>= 8
            offset -= 8
        buffer.reverse()
        mask.reverse()
        self.Buffer = buffer
        self.Mask = mask